From 9deb8c6e330e46b0616f4d84e8901e449ebb79a6 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 13 Mar 2022 11:20:31 -0500 Subject: [PATCH 01/81] Move pandas groupby outside of event loop --- distributed/shuffle/shuffle_extension.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 8f13480b91d..2c1548adfa3 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -79,12 +79,12 @@ def receive(self, output_partition: int, data: pd.DataFrame) -> None: assert not self.transferred, "`receive` called after barrier task" self.output_partitions[output_partition].append(data) - async def add_partition(self, data: pd.DataFrame) -> None: + async def add_partition(self, data: pd.DataFrame, groups: list) -> None: assert not self.transferred, "`add_partition` called after barrier task" tasks = [] # NOTE: `groupby` blocks the event loop, but it also holds the GIL, # so we don't bother offloading to a thread. See bpo-7946. - for output_partition, data in data.groupby(self.metadata.column): + for output_partition, data in groups: # NOTE: `column` must refer to an integer column, which is the output partition number for the row. # This is always `_partitions`, added by `dask/dataframe/shuffle.py::shuffle`. addr = self.metadata.worker_for(int(output_partition)) @@ -241,9 +241,13 @@ async def _create_shuffle( return metadata # NOTE: unused in tasks, just handy for tests def add_partition(self, data: pd.DataFrame, shuffle_id: ShuffleId) -> None: - sync(self.worker.loop, self._add_partition, data, shuffle_id) + column = self._get_shuffle(shuffle_id).metadata.column + groups = list(data.groupby(column)) + sync(self.worker.loop, self._add_partition, data, shuffle_id, groups) - async def _add_partition(self, data: pd.DataFrame, shuffle_id: ShuffleId) -> None: + async def _add_partition( + self, data: pd.DataFrame, shuffle_id: ShuffleId, groups: list + ) -> None: """ Task: Hand off an input partition to the ShuffleExtension. @@ -251,7 +255,7 @@ async def _add_partition(self, data: pd.DataFrame, shuffle_id: ShuffleId) -> Non Using an unknown ``shuffle_id`` is an error. """ - await self._get_shuffle(shuffle_id).add_partition(data) + await self._get_shuffle(shuffle_id).add_partition(data, groups) def barrier(self, shuffle_id: ShuffleId) -> None: sync(self.worker.loop, self._barrier, shuffle_id) From 3f629113eaf97b6f90f08e8a48878646537e71eb Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 13 Mar 2022 12:14:17 -0500 Subject: [PATCH 02/81] Add MultiFile prototype --- distributed/multi_file.py | 65 ++++++++++++++++++++++++++++ distributed/tests/test_multi_file.py | 52 ++++++++++++++++++++++ 2 files changed, 117 insertions(+) create mode 100644 distributed/multi_file.py create mode 100644 distributed/tests/test_multi_file.py diff --git a/distributed/multi_file.py b/distributed/multi_file.py new file mode 100644 index 00000000000..2b207d1165f --- /dev/null +++ b/distributed/multi_file.py @@ -0,0 +1,65 @@ +import pathlib +import pickle +import shutil + +from zict import LRU + +from dask.utils import parse_bytes + +from .system import MEMORY_LIMIT + + +class MultiFile: + def __init__( + self, + directory: pathlib.Path, + dump=pickle.dump, + load=pickle.load, + join=None, + n_files=256, + memory_limit=MEMORY_LIMIT / 2, + ): + assert join + self.directory = directory + self.dump = dump + self.load = load + self.join = join + + self.file_buffer_size = int(parse_bytes(memory_limit) / n_files) + + self.file_cache = LRU(n_files, dict(), on_evict=lambda k, v: v.close()) + + def open_file(self, id: str): + try: + return self.file_cache[id] + except KeyError: + file = open( + self.directory / id, mode="ab+", buffering=self.file_buffer_size + ) + self.file_cache[id] = file + return file + + def read(self, id): + parts = [] + file = self.open_file(id) + file.seek(0) + while True: + try: + parts.append(self.load(file)) + except EOFError: + break + return self.join(parts) + + def write(self, part, id): + file = self.open_file(id) + self.dump(part, file) + + def close(self): + shutil.rmtree(self.directory) + self.file_cache.clear() + + def __enter__(self): + return self + + def __exit__(self, exc, typ, traceback): + self.close() diff --git a/distributed/tests/test_multi_file.py b/distributed/tests/test_multi_file.py new file mode 100644 index 00000000000..0acc64e2a98 --- /dev/null +++ b/distributed/tests/test_multi_file.py @@ -0,0 +1,52 @@ +import os +import random + +import numpy as np +import pandas as pd +import pytest + +from distributed.multi_file import MultiFile + + +def test_basic(tmp_path): + with MultiFile( + directory=tmp_path, n_files=4, memory_limit="16 MiB", join=pd.concat + ) as mf: + df = pd.DataFrame({"x": np.arange(1000), "y": np.arange(1000) * 2}) + mf.write(df, "a") + mf.write(df, "b") + mf.write(df * 2, "a") + + a = mf.read("a") + b = mf.read("b") + + assert (df == b).all().all() + assert (pd.concat([df, df * 2]) == a).all().all() + + assert not os.path.exists(tmp_path) + + +@pytest.mark.parametrize("count", [2, 100, 1000]) +def test_many(tmp_path, count): + with MultiFile( + directory=tmp_path, n_files=4, memory_limit="16 MiB", join=pd.concat + ) as mf: + df = pd.DataFrame({"x": np.arange(10), "y": np.arange(10) * 2}) + + L = list(range(count)) + + random.shuffle(L) + for i in L: + mf.write(df + i, str(i)) + + random.shuffle(L) + for i in L: + mf.write(df * i, str(i)) + + random.shuffle(L) + for i in L: + out = mf.read(str(i)) + + assert (pd.concat([df + i, df * i]) == out).all().all() + + assert not os.path.exists(tmp_path) From 626eda00efce4a2bdbd526f2c27e9b7f63a94a45 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 13 Mar 2022 14:26:34 -0500 Subject: [PATCH 03/81] Integrate MultiFile with shuffle extension --- distributed/multi_file.py | 43 +++++++++++++++++------- distributed/shuffle/shuffle_extension.py | 32 ++++++++++++------ distributed/tests/test_multi_file.py | 6 ++-- 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/distributed/multi_file.py b/distributed/multi_file.py index 2b207d1165f..470b3245a29 100644 --- a/distributed/multi_file.py +++ b/distributed/multi_file.py @@ -1,8 +1,10 @@ +import os import pathlib import pickle import shutil +import threading -from zict import LRU +import zict from dask.utils import parse_bytes @@ -12,46 +14,61 @@ class MultiFile: def __init__( self, - directory: pathlib.Path, + directory, dump=pickle.dump, load=pickle.load, join=None, n_files=256, memory_limit=MEMORY_LIMIT / 2, + file_cache=None, ): - assert join - self.directory = directory + if not join: + import pandas as pd + + join = pd.concat + self.directory = pathlib.Path(directory) + if not os.path.exists(self.directory): + os.mkdir(self.directory) self.dump = dump self.load = load self.join = join + self.lock = threading.Lock() self.file_buffer_size = int(parse_bytes(memory_limit) / n_files) - self.file_cache = LRU(n_files, dict(), on_evict=lambda k, v: v.close()) + if file_cache is None: + file_cache = zict.LRU(n_files, dict(), on_evict=lambda k, v: v.close()) + self.file_cache = file_cache def open_file(self, id: str): - try: - return self.file_cache[id] - except KeyError: - file = open( - self.directory / id, mode="ab+", buffering=self.file_buffer_size - ) - self.file_cache[id] = file - return file + with self.lock: + try: + return self.file_cache[id] + except KeyError: + file = open( + self.directory / str(id), + mode="ab+", + buffering=self.file_buffer_size, + ) + self.file_cache[id] = file + return file def read(self, id): parts = [] file = self.open_file(id) file.seek(0) + # TODO: Note that this is unsafe to multiple threads trying to read the same file while True: try: parts.append(self.load(file)) except EOFError: break + # TODO: We could consider deleting the file at this point return self.join(parts) def write(self, part, id): file = self.open_file(id) + # TODO: We should consider offloading this to a separate thread self.dump(part, file) def close(self): diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 2c1548adfa3..9cab6e0ecb8 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -2,10 +2,13 @@ import asyncio import math -from collections import defaultdict +import os from dataclasses import dataclass from typing import TYPE_CHECKING, NewType +import zict + +from distributed.multi_file import MultiFile from distributed.protocol import to_serialize from distributed.utils import sync @@ -68,16 +71,25 @@ def npartitions_for(self, worker: str) -> int: class Shuffle: "State for a single active shuffle" - def __init__(self, metadata: ShuffleMetadata, worker: Worker) -> None: + def __init__( + self, metadata: ShuffleMetadata, worker: Worker, file_cache=None + ) -> None: self.metadata = metadata self.worker = worker - self.output_partitions: defaultdict[int, list[pd.DataFrame]] = defaultdict(list) + + self.multi_file = MultiFile( + directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), + memory_limit="1 GiB", # TODO: lift this up to the global ShuffleExtension + file_cache=file_cache, + n_files=min(256, metadata.npartitions), + ) + self.output_partitions_left = metadata.npartitions_for(worker.address) self.transferred = False def receive(self, output_partition: int, data: pd.DataFrame) -> None: assert not self.transferred, "`receive` called after barrier task" - self.output_partitions[output_partition].append(data) + self.multi_file.write(data, output_partition) async def add_partition(self, data: pd.DataFrame, groups: list) -> None: assert not self.transferred, "`add_partition` called after barrier task" @@ -103,8 +115,6 @@ async def add_partition(self, data: pd.DataFrame, groups: list) -> None: await asyncio.gather(*tasks) def get_output_partition(self, i: int) -> pd.DataFrame: - import pandas as pd - assert self.transferred, "`get_output_partition` called before barrier task" assert self.metadata.worker_for(i) == self.worker.address, ( @@ -120,13 +130,10 @@ def get_output_partition(self, i: int) -> pd.DataFrame: self.output_partitions_left -= 1 try: - parts = self.output_partitions.pop(i) + return self.multi_file.read(i) except KeyError: return self.metadata.empty - assert parts, f"Empty entry for output partition {i}" - return pd.concat(parts, copy=False) - def inputs_done(self) -> None: assert not self.transferred, "`inputs_done` called multiple times" self.transferred = True @@ -148,6 +155,7 @@ def __init__(self, worker: Worker) -> None: # Initialize self.worker: Worker = worker self.shuffles: dict[ShuffleId, Shuffle] = {} + self.file_cache = zict.LRU(256, dict(), on_evict=lambda k, v: v.close()) # Handlers ########## @@ -162,7 +170,9 @@ def shuffle_init(self, comm: object, metadata: ShuffleMetadata) -> None: raise ValueError( f"Shuffle {metadata.id!r} is already registered on worker {self.worker.address}" ) - self.shuffles[metadata.id] = Shuffle(metadata, self.worker) + self.shuffles[metadata.id] = Shuffle( + metadata, self.worker, file_cache=self.file_cache + ) def shuffle_receive( self, diff --git a/distributed/tests/test_multi_file.py b/distributed/tests/test_multi_file.py index 0acc64e2a98..0628d8c0a6a 100644 --- a/distributed/tests/test_multi_file.py +++ b/distributed/tests/test_multi_file.py @@ -37,15 +37,15 @@ def test_many(tmp_path, count): random.shuffle(L) for i in L: - mf.write(df + i, str(i)) + mf.write(df + i, i) random.shuffle(L) for i in L: - mf.write(df * i, str(i)) + mf.write(df * i, i) random.shuffle(L) for i in L: - out = mf.read(str(i)) + out = mf.read(i) assert (pd.concat([df + i, df * i]) == out).all().all() From 7f48f778c5bb20b75fecd660f9f04025a92ce046 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 13 Mar 2022 19:58:10 -0500 Subject: [PATCH 04/81] Add buffered comms --- distributed/multi_comm.py | 99 +++++++++++++++++++++++ distributed/multi_file.py | 12 ++- distributed/shuffle/shuffle_extension.py | 80 ++++++++---------- distributed/shuffle/tests/test_shuffle.py | 14 ++++ 4 files changed, 155 insertions(+), 50 deletions(-) create mode 100644 distributed/multi_comm.py create mode 100644 distributed/shuffle/tests/test_shuffle.py diff --git a/distributed/multi_comm.py b/distributed/multi_comm.py new file mode 100644 index 00000000000..d3621f3f569 --- /dev/null +++ b/distributed/multi_comm.py @@ -0,0 +1,99 @@ +import asyncio +import threading +from collections import defaultdict + +from dask.utils import parse_bytes + +from .core import rpc +from .protocol import to_serialize +from .sizeof import sizeof +from .system import MEMORY_LIMIT +from .utils import log_errors, offload + + +class MultiComm: + def __init__( + self, + memory_limit=MEMORY_LIMIT / 4, + join=None, + rpc=rpc, + sizeof=sizeof, + max_connections=10, + shuffle_id=None, + ): + self.lock = threading.Lock() + self.shards = defaultdict(list) + self.sizes = defaultdict(int) + self.total_size = 0 + self.memory_limit = parse_bytes(memory_limit) + self.thread_condition = threading.Condition() + if join is None: + import pandas as pd + + join = pd.concat + self.join = join + self.max_connections = max_connections + self.sizeof = sizeof + self.shuffle_id = shuffle_id + self._futures = set() + self._done = False + self.rpc = rpc + + def put(self, data: dict): + with self.lock: + for address, shard in data.items(): + size = self.sizeof(shard) + self.shards[address].append(shard) + self.sizes[address] += size + self.total_size += size + + del data + + while self.total_size > self.memory_limit: + with self.thread_condition: + self.thread_condition.wait(0.500) # Block until memory calms down + + async def communicate(self): + self.comm_queue = asyncio.Queue(maxsize=self.max_connections) + for _ in range(self.max_connections): + self.comm_queue.put_nowait(None) + + while not self._done: + if not self.shards: + await asyncio.sleep(0.1) + continue + + await self.comm_queue.get() + + with self.lock: + address = max(self.sizes, key=self.sizes.get) + shards = self.shards.pop(address) + size = self.sizes.pop(address) + + future = asyncio.ensure_future(self.process(address, shards, size)) + del shards + self._futures.add(future) + with self.thread_condition: + self.thread_condition.notify() + + async def process(self, address: str, shards: list, size: int): + with log_errors(): + shards = await offload(self.join, shards) + # Consider boosting total_size a bit here to account for duplication + + try: + await self.rpc(address).shuffle_receive( + data=to_serialize(shards), + shuffle_id=self.shuffle_id, + ) + finally: + self.total_size -= size + await self.comm_queue.put(None) + + async def flush(self): + while self.shards: + await asyncio.sleep(0.05) + + await asyncio.gather(*self._futures) + assert not self.total_size + self._done = True diff --git a/distributed/multi_file.py b/distributed/multi_file.py index 470b3245a29..ba2d50f4458 100644 --- a/distributed/multi_file.py +++ b/distributed/multi_file.py @@ -9,6 +9,7 @@ from dask.utils import parse_bytes from .system import MEMORY_LIMIT +from .utils import offload class MultiFile: @@ -64,12 +65,15 @@ def read(self, id): except EOFError: break # TODO: We could consider deleting the file at this point - return self.join(parts) + if parts: + return self.join(parts) + else: + raise KeyError(id) - def write(self, part, id): - file = self.open_file(id) + async def write(self, part, id): + file = await offload(self.open_file, id) # TODO: We should consider offloading this to a separate thread - self.dump(part, file) + await offload(self.dump, part, file) def close(self): shutil.rmtree(self.directory) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 9cab6e0ecb8..af71f6c664c 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -8,9 +8,10 @@ import zict +from distributed.multi_comm import MultiComm from distributed.multi_file import MultiFile from distributed.protocol import to_serialize -from distributed.utils import sync +from distributed.utils import offload, sync if TYPE_CHECKING: import pandas as pd @@ -79,40 +80,39 @@ def __init__( self.multi_file = MultiFile( directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), - memory_limit="1 GiB", # TODO: lift this up to the global ShuffleExtension + memory_limit="200 MiB", # TODO: lift this up to the global ShuffleExtension file_cache=file_cache, n_files=min(256, metadata.npartitions), ) + self.multi_comm = MultiComm( + memory_limit="200 MiB", # TODO + rpc=worker.rpc, + shuffle_id=self.metadata.id, + # sizeof= # TODO, something smarter + ) + self.worker.loop.add_callback(self.multi_comm.communicate) self.output_partitions_left = metadata.npartitions_for(worker.address) self.transferred = False - def receive(self, output_partition: int, data: pd.DataFrame) -> None: + async def receive(self, data: pd.DataFrame) -> None: assert not self.transferred, "`receive` called after barrier task" - self.multi_file.write(data, output_partition) - - async def add_partition(self, data: pd.DataFrame, groups: list) -> None: - assert not self.transferred, "`add_partition` called after barrier task" - tasks = [] - # NOTE: `groupby` blocks the event loop, but it also holds the GIL, - # so we don't bother offloading to a thread. See bpo-7946. - for output_partition, data in groups: - # NOTE: `column` must refer to an integer column, which is the output partition number for the row. - # This is always `_partitions`, added by `dask/dataframe/shuffle.py::shuffle`. - addr = self.metadata.worker_for(int(output_partition)) - task = asyncio.create_task( - self.worker.rpc(addr).shuffle_receive( - shuffle_id=self.metadata.id, - output_partition=output_partition, - data=to_serialize(data), - ) - ) - tasks.append(task) + groups = await offload(lambda: list(data.groupby(self.metadata.column))) + for output_partition, shard in groups: + await self.multi_file.write(shard, output_partition) + + def add_partition(self, data: pd.DataFrame) -> None: - # TODO Once RerunGroup logic exists (https://github.com/dask/distributed/issues/5403), - # handle errors and cancellation here in a way that lets other workers cancel & clean up their shuffles. - # Without it, letting errors kill the task is all we can do. - await asyncio.gather(*tasks) + grouper = ( + len(self.metadata.workers) + * data[self.metadata.column] + // self.metadata.npartitions + ) # .astype(data[self.metadata.column].dtype) + + groups = list(data.groupby(grouper)) + out = {self.metadata.workers[int(i)]: shard for i, shard in groups} + assert len(data) == sum(len(df) for _, df in out.items()) + self.multi_comm.put(out) def get_output_partition(self, i: int) -> pd.DataFrame: assert self.transferred, "`get_output_partition` called before barrier task" @@ -130,7 +130,8 @@ def get_output_partition(self, i: int) -> pd.DataFrame: self.output_partitions_left -= 1 try: - return self.multi_file.read(i) + df = self.multi_file.read(i) + return df except KeyError: return self.metadata.empty @@ -174,25 +175,26 @@ def shuffle_init(self, comm: object, metadata: ShuffleMetadata) -> None: metadata, self.worker, file_cache=self.file_cache ) - def shuffle_receive( + async def shuffle_receive( self, comm: object, shuffle_id: ShuffleId, - output_partition: int, data: pd.DataFrame, ) -> None: """ Hander: Receive an incoming shard of data from a peer worker. Using an unknown ``shuffle_id`` is an error. """ - self._get_shuffle(shuffle_id).receive(output_partition, data) + await self._get_shuffle(shuffle_id).receive(data) - def shuffle_inputs_done(self, comm: object, shuffle_id: ShuffleId) -> None: + async def shuffle_inputs_done(self, comm: object, shuffle_id: ShuffleId) -> None: """ Hander: Inform the extension that all input partitions have been handed off to extensions. Using an unknown ``shuffle_id`` is an error. """ shuffle = self._get_shuffle(shuffle_id) + await shuffle.multi_comm.flush() + await asyncio.sleep(1) # TODO shuffle.inputs_done() if shuffle.done(): # If the shuffle has no output partitions, remove it now; @@ -251,21 +253,7 @@ async def _create_shuffle( return metadata # NOTE: unused in tasks, just handy for tests def add_partition(self, data: pd.DataFrame, shuffle_id: ShuffleId) -> None: - column = self._get_shuffle(shuffle_id).metadata.column - groups = list(data.groupby(column)) - sync(self.worker.loop, self._add_partition, data, shuffle_id, groups) - - async def _add_partition( - self, data: pd.DataFrame, shuffle_id: ShuffleId, groups: list - ) -> None: - """ - Task: Hand off an input partition to the ShuffleExtension. - - This will block until the extension is ready to receive another input partition. - - Using an unknown ``shuffle_id`` is an error. - """ - await self._get_shuffle(shuffle_id).add_partition(data, groups) + self._get_shuffle(shuffle_id).add_partition(data=data) def barrier(self, shuffle_id: ShuffleId) -> None: sync(self.worker.loop, self._barrier, shuffle_id) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py new file mode 100644 index 00000000000..7f2dee71a0e --- /dev/null +++ b/distributed/shuffle/tests/test_shuffle.py @@ -0,0 +1,14 @@ +import dask +import dask.dataframe as dd + +from distributed.utils_test import gen_cluster + + +@gen_cluster(client=True) +async def test_basic(c, s, a, b): + df = dask.datasets.timeseries() + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + a, b = c.compute([df.x.size, out.x.size]) + a = await a + b = await b + assert a == b From 99bb283f0087479a942e9bb33c13f1f28b2f12f2 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 13 Mar 2022 20:09:58 -0500 Subject: [PATCH 05/81] Move multi files to shuffle/ --- distributed/{ => shuffle}/multi_comm.py | 10 +++++----- distributed/{ => shuffle}/multi_file.py | 4 ++-- distributed/shuffle/shuffle_extension.py | 4 ++-- .../{ => shuffle}/tests/test_multi_file.py | 18 ++++++++++-------- 4 files changed, 19 insertions(+), 17 deletions(-) rename distributed/{ => shuffle}/multi_comm.py (94%) rename distributed/{ => shuffle}/multi_file.py (97%) rename distributed/{ => shuffle}/tests/test_multi_file.py (74%) diff --git a/distributed/multi_comm.py b/distributed/shuffle/multi_comm.py similarity index 94% rename from distributed/multi_comm.py rename to distributed/shuffle/multi_comm.py index d3621f3f569..7c8f1eb32fc 100644 --- a/distributed/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -4,11 +4,11 @@ from dask.utils import parse_bytes -from .core import rpc -from .protocol import to_serialize -from .sizeof import sizeof -from .system import MEMORY_LIMIT -from .utils import log_errors, offload +from ..core import rpc +from ..protocol import to_serialize +from ..sizeof import sizeof +from ..system import MEMORY_LIMIT +from ..utils import log_errors, offload class MultiComm: diff --git a/distributed/multi_file.py b/distributed/shuffle/multi_file.py similarity index 97% rename from distributed/multi_file.py rename to distributed/shuffle/multi_file.py index ba2d50f4458..869bce3ee14 100644 --- a/distributed/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -8,8 +8,8 @@ from dask.utils import parse_bytes -from .system import MEMORY_LIMIT -from .utils import offload +from ..system import MEMORY_LIMIT +from ..utils import offload class MultiFile: diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index af71f6c664c..f5f4523c215 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -8,9 +8,9 @@ import zict -from distributed.multi_comm import MultiComm -from distributed.multi_file import MultiFile from distributed.protocol import to_serialize +from distributed.shuffle.multi_comm import MultiComm +from distributed.shuffle.multi_file import MultiFile from distributed.utils import offload, sync if TYPE_CHECKING: diff --git a/distributed/tests/test_multi_file.py b/distributed/shuffle/tests/test_multi_file.py similarity index 74% rename from distributed/tests/test_multi_file.py rename to distributed/shuffle/tests/test_multi_file.py index 0628d8c0a6a..671a56aeaf5 100644 --- a/distributed/tests/test_multi_file.py +++ b/distributed/shuffle/tests/test_multi_file.py @@ -5,17 +5,18 @@ import pandas as pd import pytest -from distributed.multi_file import MultiFile +from distributed.shuffle.multi_file import MultiFile -def test_basic(tmp_path): +@pytest.mark.asyncio +async def test_basic(tmp_path): with MultiFile( directory=tmp_path, n_files=4, memory_limit="16 MiB", join=pd.concat ) as mf: df = pd.DataFrame({"x": np.arange(1000), "y": np.arange(1000) * 2}) - mf.write(df, "a") - mf.write(df, "b") - mf.write(df * 2, "a") + await mf.write(df, "a") + await mf.write(df, "b") + await mf.write(df * 2, "a") a = mf.read("a") b = mf.read("b") @@ -26,8 +27,9 @@ def test_basic(tmp_path): assert not os.path.exists(tmp_path) +@pytest.mark.asyncio @pytest.mark.parametrize("count", [2, 100, 1000]) -def test_many(tmp_path, count): +async def test_many(tmp_path, count): with MultiFile( directory=tmp_path, n_files=4, memory_limit="16 MiB", join=pd.concat ) as mf: @@ -37,11 +39,11 @@ def test_many(tmp_path, count): random.shuffle(L) for i in L: - mf.write(df + i, i) + await mf.write(df + i, i) random.shuffle(L) for i in L: - mf.write(df * i, i) + await mf.write(df * i, i) random.shuffle(L) for i in L: From 0d49ab92df61373fd6d4af4ac0c39713c200db4c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 15 Mar 2022 10:59:41 -0500 Subject: [PATCH 06/81] add arrow Performance is good, still need to track down memory --- distributed/shuffle/multi_comm.py | 18 ++- distributed/shuffle/multi_file.py | 12 +- distributed/shuffle/shuffle.py | 10 +- distributed/shuffle/shuffle_extension.py | 142 +++++++++++++++--- .../shuffle/tests/test_shuffle_extension.py | 79 ++++++++++ 5 files changed, 230 insertions(+), 31 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index 7c8f1eb32fc..faa1db4ec1c 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -8,7 +8,7 @@ from ..protocol import to_serialize from ..sizeof import sizeof from ..system import MEMORY_LIMIT -from ..utils import log_errors, offload +from ..utils import log_errors class MultiComm: @@ -25,12 +25,10 @@ def __init__( self.shards = defaultdict(list) self.sizes = defaultdict(int) self.total_size = 0 + self.total_moved = 0 self.memory_limit = parse_bytes(memory_limit) self.thread_condition = threading.Condition() - if join is None: - import pandas as pd - - join = pd.concat + assert join self.join = join self.max_connections = max_connections self.sizeof = sizeof @@ -46,12 +44,13 @@ def put(self, data: dict): self.shards[address].append(shard) self.sizes[address] += size self.total_size += size + self.total_moved += size del data while self.total_size > self.memory_limit: with self.thread_condition: - self.thread_condition.wait(0.500) # Block until memory calms down + self.thread_condition.wait(0.100) # Block until memory calms down async def communicate(self): self.comm_queue = asyncio.Queue(maxsize=self.max_connections) @@ -78,7 +77,9 @@ async def communicate(self): async def process(self, address: str, shards: list, size: int): with log_errors(): - shards = await offload(self.join, shards) + shards = self.join(shards) + # shards = await offload(self.join, shards) + # Consider boosting total_size a bit here to account for duplication try: @@ -96,4 +97,7 @@ async def flush(self): await asyncio.gather(*self._futures) assert not self.total_size + from dask.utils import format_bytes + + print("total moved", format_bytes(self.total_moved)) self._done = True diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 869bce3ee14..0f479f36501 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -6,6 +6,7 @@ import zict +from dask.sizeof import sizeof from dask.utils import parse_bytes from ..system import MEMORY_LIMIT @@ -22,11 +23,9 @@ def __init__( n_files=256, memory_limit=MEMORY_LIMIT / 2, file_cache=None, + sizeof=sizeof, ): - if not join: - import pandas as pd - - join = pd.concat + assert join self.directory = pathlib.Path(directory) if not os.path.exists(self.directory): os.mkdir(self.directory) @@ -40,6 +39,8 @@ def __init__( if file_cache is None: file_cache = zict.LRU(n_files, dict(), on_evict=lambda k, v: v.close()) self.file_cache = file_cache + self.bytes_written = 0 + self.bytes_read = 0 def open_file(self, id: str): with self.lock: @@ -66,6 +67,8 @@ def read(self, id): break # TODO: We could consider deleting the file at this point if parts: + for part in parts: + self.bytes_read += sizeof(part) return self.join(parts) else: raise KeyError(id) @@ -73,6 +76,7 @@ def read(self, id): async def write(self, part, id): file = await offload(self.open_file, id) # TODO: We should consider offloading this to a separate thread + self.bytes_written += sizeof(part) await offload(self.dump, part, file) def close(self): diff --git a/distributed/shuffle/shuffle.py b/distributed/shuffle/shuffle.py index 33fe1189a09..e1ef318a184 100644 --- a/distributed/shuffle/shuffle.py +++ b/distributed/shuffle/shuffle.py @@ -59,10 +59,18 @@ def rearrange_by_column_p2p( npartitions = npartitions or df.npartitions token = tokenize(df, column, npartitions) + empty = df._meta.copy() + for c, dt in empty.dtypes.items(): + if dt == object: + empty[c] = empty[c].astype( + "string" + ) # TODO: we fail at non-string object dtypes + empty[column] = empty[column].astype("int64") # TODO: this shouldn't be necesssary + setup = delayed(shuffle_setup, pure=True)( NewShuffleMetadata( ShuffleId(token), - df._meta, + empty, column, npartitions, ) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index f5f4523c215..ee81a277e57 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -1,11 +1,13 @@ from __future__ import annotations import asyncio +import functools import math import os from dataclasses import dataclass from typing import TYPE_CHECKING, NewType +import toolz import zict from distributed.protocol import to_serialize @@ -15,6 +17,7 @@ if TYPE_CHECKING: import pandas as pd + import pyarrow as pa from distributed.worker import Worker @@ -53,8 +56,7 @@ def worker_for(self, output_partition: int) -> str: raise IndexError( f"Output partition {output_partition} does not exist in a shuffle producing {self.npartitions} partitions" ) - i = len(self.workers) * output_partition // self.npartitions - return self.workers[i] + return worker_for(output_partition, self.workers, self.npartitions) def _partition_range(self, worker: str) -> tuple[int, int]: "Get the output partition numbers (inclusive) that a worker will hold" @@ -78,40 +80,61 @@ def __init__( self.metadata = metadata self.worker = worker + import pyarrow as pa + self.multi_file = MultiFile( + dump=dump_arrow, + load=load_arrow, directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), memory_limit="200 MiB", # TODO: lift this up to the global ShuffleExtension file_cache=file_cache, n_files=min(256, metadata.npartitions), + join=pa.concat_tables, # pd.concat ) self.multi_comm = MultiComm( memory_limit="200 MiB", # TODO rpc=worker.rpc, shuffle_id=self.metadata.id, - # sizeof= # TODO, something smarter + sizeof=lambda L: sum(map(len, L)), + join=functools.partial(sum, start=[]), ) self.worker.loop.add_callback(self.multi_comm.communicate) self.output_partitions_left = metadata.npartitions_for(worker.address) self.transferred = False + self.total_recvd = 0 - async def receive(self, data: pd.DataFrame) -> None: + async def receive(self, data: list[pa.Buffer]) -> None: assert not self.transferred, "`receive` called after barrier task" - groups = await offload(lambda: list(data.groupby(self.metadata.column))) - for output_partition, shard in groups: - await self.multi_file.write(shard, output_partition) + self.total_recvd += sum(map(len, data)) + # An ugly way of turning these batches back into an arrow table + import io - def add_partition(self, data: pd.DataFrame) -> None: + import pyarrow as pa - grouper = ( - len(self.metadata.workers) - * data[self.metadata.column] - // self.metadata.npartitions - ) # .astype(data[self.metadata.column].dtype) + bio = io.BytesIO() + bio.write(pa.Schema.from_pandas(self.metadata.empty).serialize()) + for batch in data: + bio.write(batch) + bio.seek(0) + sr = pa.RecordBatchStreamReader(bio) + data = sr.read_all() - groups = list(data.groupby(grouper)) - out = {self.metadata.workers[int(i)]: shard for i, shard in groups} - assert len(data) == sum(len(df) for _, df in out.items()) + groups = await offload(split_by_partition, data, self.metadata.column) + + assert len(data) == sum(map(len, groups.values())) + for output_partition, shard in groups.items(): + await self.multi_file.write(shard, output_partition) + + def add_partition(self, data: pd.DataFrame) -> None: + out = split_by_worker( + data, self.metadata.column, self.metadata.npartitions, self.metadata.workers + ) + assert len(data) == sum(map(len, out.values())) + out = { + k: [b.serialize().to_pybytes() for b in t.to_batches()] + for k, t in out.items() + } self.multi_comm.put(out) def get_output_partition(self, i: int) -> pd.DataFrame: @@ -131,9 +154,9 @@ def get_output_partition(self, i: int) -> pd.DataFrame: try: df = self.multi_file.read(i) - return df + return df.to_pandas() except KeyError: - return self.metadata.empty + return self.metadata.empty.head(0) def inputs_done(self) -> None: assert not self.transferred, "`inputs_done` called multiple times" @@ -319,8 +342,8 @@ def get_output_partition( shuffle = self._get_shuffle(shuffle_id) output = shuffle.get_output_partition(output_partition) if shuffle.done(): - # key missing if another thread got to it first self.shuffles.pop(shuffle_id, None) + # key missing if another thread got to it first return output def _get_shuffle(self, shuffle_id: ShuffleId) -> Shuffle: @@ -331,3 +354,84 @@ def _get_shuffle(self, shuffle_id: ShuffleId) -> Shuffle: raise ValueError( f"Shuffle {shuffle_id!r} is not registered on worker {self.worker.address}" ) from None + + +def split_by_worker( + df: pd.DataFrame, column: str, npartitions: int, workers: list[str] +) -> dict: + """ + Split data into many arrow batches, partitioned by destination worker + """ + import numpy as np + import pyarrow as pa + + grouper = (len(workers) * df[column] // npartitions).astype(df[column].dtype).values + + t = pa.Table.from_pandas(df) + del df + t = t.add_column(len(t.columns), "_worker", [grouper]) + t = t.sort_by("_worker") + + worker = np.asarray(t.select(["_worker"]))[0] + t = t.drop(["_worker"]) + splits = np.where(worker[1:] != worker[:-1])[0] + 1 + splits = np.concatenate([[0], splits]) + + shards = [ + t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits) + ] + shards.append(t.slice(offset=splits[-1], length=None)) + + w = np.unique(grouper) + w.sort() + + return {workers[w]: shard for w, shard in zip(w, shards)} + + +def split_by_partition( + t: pa.Table, + column: str, +) -> dict: + """ + Split data into many arrow batches, partitioned by destination worker + """ + import numpy as np + + partitions = np.unique(np.asarray(t.select([column]))[0]) + partitions.sort() + t = t.sort_by(column) + + partition = np.asarray(t.select([column]))[0] + splits = np.where(partition[1:] != partition[:-1])[0] + 1 + splits = np.concatenate([[0], splits]) + + shards = [ + t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits) + ] + shards.append(t.slice(offset=splits[-1], length=None)) + assert len(t) == sum(map(len, shards)) + assert len(partitions) == len(shards) + return dict(zip(partitions, shards)) + + +def dump_arrow(t: pa.Table, file): + if file.tell() == 0: + file.write(t.schema.serialize()) + for batch in t.to_batches(): + file.write(batch.serialize()) + + +def load_arrow(file): + import pyarrow as pa + + try: + sr = pa.RecordBatchStreamReader(file) + return sr.read_all() + except Exception: + raise EOFError + + +def worker_for(output_partition: int, workers: list[str], npartitions: int) -> str: + "Get the address of the worker which should hold this output partition number" + i = len(workers) * output_partition // npartitions + return workers[i] diff --git a/distributed/shuffle/tests/test_shuffle_extension.py b/distributed/shuffle/tests/test_shuffle_extension.py index ad536280672..efdaeb376cf 100644 --- a/distributed/shuffle/tests/test_shuffle_extension.py +++ b/distributed/shuffle/tests/test_shuffle_extension.py @@ -17,6 +17,11 @@ ShuffleId, ShuffleMetadata, ShuffleWorkerExtension, + dump_arrow, + load_arrow, + split_by_partition, + split_by_worker, + worker_for, ) if TYPE_CHECKING: @@ -286,3 +291,77 @@ async def test_get_partition(c: Client, s: Scheduler, *workers: Worker): assert not ext.shuffles with pytest.raises(ValueError, match="not registered"): ext.get_output_partition(metadata.id, 0) + + +def test_split_by_worker(): + df = pd.DataFrame( + { + "x": [1, 2, 3, 4, 5], + "_partition": [0, 1, 2, 0, 1], + } + ) + workers = ["alice", "bob"] + npartitions = 3 + + out = split_by_worker(df, "_partition", npartitions, workers) + assert set(out) == {"alice", "bob"} + assert out["alice"].column_names == list(df.columns) + + assert sum(map(len, out.values())) == len(df) + + +def test_split_by_worker_many_workers(): + df = pd.DataFrame( + { + "x": [1, 2, 3, 4, 5], + "_partition": [5, 7, 5, 0, 1], + } + ) + workers = ["a", "b", "c", "d", "e", "f", "g", "h"] + npartitions = 10 + + out = split_by_worker(df, "_partition", npartitions, workers) + assert worker_for(5, workers, npartitions) in out + assert worker_for(0, workers, npartitions) in out + assert worker_for(7, workers, npartitions) in out + assert worker_for(1, workers, npartitions) in out + + assert sum(map(len, out.values())) == len(df) + + +def test_split_by_partition(): + import pyarrow as pa + + df = pd.DataFrame( + { + "x": [1, 2, 3, 4, 5], + "_partition": [3, 1, 2, 3, 1], + } + ) + t = pa.Table.from_pandas(df) + + out = split_by_partition(t, "_partition") + assert set(out) == {1, 2, 3} + assert out[1].column_names == list(df.columns) + assert sum(map(len, out.values())) == len(df) + + +def test_load_dump_arrow(tmp_path): + import pyarrow as pa + + df = pd.DataFrame( + { + "x": [1, 2, 3, 4, 5], + "_partition": [3, 1, 2, 3, 1], + } + ) + t = pa.Table.from_pandas(df) + with open(tmp_path / "foo", mode="wb") as f: + dump_arrow(t, f) + dump_arrow(t, f) + dump_arrow(t, f) + + with open(tmp_path / "foo", mode="rb") as f: + tt = load_arrow(f) + + assert str(tt) == str(pa.concat_tables([t, t, t])) From 1e1311d8e8e0ce48ca2f755c43548276d3ea7d1e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 15 Mar 2022 15:33:58 -0500 Subject: [PATCH 07/81] Handle buffers manually in multi_file This manages memory more smoothly. We still have issues though in that we're still passing around slices of arrow tables, which hold onto large references --- distributed/shuffle/multi_comm.py | 23 +++- distributed/shuffle/multi_file.py | 134 ++++++++++++++++------ distributed/shuffle/shuffle_extension.py | 10 +- distributed/shuffle/tests/test_shuffle.py | 6 + 4 files changed, 132 insertions(+), 41 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index faa1db4ec1c..10b27911a5a 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -20,12 +20,14 @@ def __init__( sizeof=sizeof, max_connections=10, shuffle_id=None, + max_message_size="128 MiB", ): self.lock = threading.Lock() self.shards = defaultdict(list) self.sizes = defaultdict(int) self.total_size = 0 self.total_moved = 0 + self.max_message_size = parse_bytes(max_message_size) self.memory_limit = parse_bytes(memory_limit) self.thread_condition = threading.Condition() assert join @@ -66,14 +68,25 @@ async def communicate(self): with self.lock: address = max(self.sizes, key=self.sizes.get) - shards = self.shards.pop(address) - size = self.sizes.pop(address) + + size = 0 + shards = [] + while size < self.max_message_size: + try: + shard = self.shards[address].pop() + except IndexError: + del self.shards[address] + del self.sizes[address] + break + else: + shards.append(shard) + s = self.sizeof(shard) + size += s + self.sizes[address] -= s future = asyncio.ensure_future(self.process(address, shards, size)) del shards self._futures.add(future) - with self.thread_condition: - self.thread_condition.notify() async def process(self, address: str, shards: list, size: int): with log_errors(): @@ -89,6 +102,8 @@ async def process(self, address: str, shards: list, size: int): ) finally: self.total_size -= size + with self.thread_condition: + self.thread_condition.notify() await self.comm_queue.put(None) async def flush(self): diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 0f479f36501..86f4c85eca6 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -1,16 +1,15 @@ +import asyncio import os import pathlib import pickle import shutil -import threading - -import zict +from collections import defaultdict from dask.sizeof import sizeof from dask.utils import parse_bytes from ..system import MEMORY_LIMIT -from ..utils import offload +from ..utils import log_errors, offload class MultiFile: @@ -20,7 +19,7 @@ def __init__( dump=pickle.dump, load=pickle.load, join=None, - n_files=256, + concurrent_files=1, memory_limit=MEMORY_LIMIT / 2, file_cache=None, sizeof=sizeof, @@ -32,39 +31,104 @@ def __init__( self.dump = dump self.load = load self.join = join - self.lock = threading.Lock() + self.sizeof = sizeof + + self.shards = defaultdict(list) + self.sizes = defaultdict(int) + self.total_size = 0 + self.total_received = 0 - self.file_buffer_size = int(parse_bytes(memory_limit) / n_files) + self.memory_limit = parse_bytes(memory_limit) + self.concurrent_files = concurrent_files + self.condition = asyncio.Condition() - if file_cache is None: - file_cache = zict.LRU(n_files, dict(), on_evict=lambda k, v: v.close()) - self.file_cache = file_cache self.bytes_written = 0 self.bytes_read = 0 - def open_file(self, id: str): - with self.lock: - try: - return self.file_cache[id] - except KeyError: - file = open( - self.directory / str(id), - mode="ab+", - buffering=self.file_buffer_size, + self._done = False + self._futures = set() + + async def put(self, data: dict): + this_size = 0 + for id, shard in data.items(): + size = self.sizeof(shard) + self.shards[id].append(shard) + self.sizes[id] += size + self.total_size += size + self.total_received += size + this_size += size + + del data + + while self.total_size > self.memory_limit: + async with self.condition: + from dask.utils import format_bytes + + print( + "waiting", + format_bytes(self.total_size), + "this", + format_bytes(this_size), ) - self.file_cache[id] = file - return file + try: + await asyncio.wait_for( + self.condition.wait(), 1 + ) # Block until memory calms down + except asyncio.TimeoutError: + continue + + async def communicate(self): + with log_errors(): + self.queue = asyncio.Queue(maxsize=self.concurrent_files) + for _ in range(self.concurrent_files): + self.queue.put_nowait(None) + + while not self._done: + if not self.shards: + await asyncio.sleep(0.1) + continue + + await self.queue.get() + + id = max(self.sizes, key=self.sizes.get) + shards = self.shards.pop(id) + size = self.sizes.pop(id) + + future = asyncio.ensure_future(self.process(id, shards, size)) + del shards + self._futures.add(future) + async with self.condition: + self.condition.notify() + + async def process(self, id: str, shards: list, size: int): + with log_errors(): + # Consider boosting total_size a bit here to account for duplication + + def _(): + # TODO: offload + with open( + self.directory / str(id), mode="ab", buffering=100_000_000 + ) as f: + for shard in shards: + self.dump(shard, f) + + await offload(_) + + self.total_size -= size + async with self.condition: + self.condition.notify() + await self.queue.put(None) def read(self, id): parts = [] - file = self.open_file(id) - file.seek(0) - # TODO: Note that this is unsafe to multiple threads trying to read the same file - while True: - try: - parts.append(self.load(file)) - except EOFError: - break + + with open(self.directory / str(id), mode="rb", buffering=100_000_000) as f: + while True: + try: + parts.append(self.load(f)) + except EOFError: + break + # TODO: We could consider deleting the file at this point if parts: for part in parts: @@ -73,11 +137,13 @@ def read(self, id): else: raise KeyError(id) - async def write(self, part, id): - file = await offload(self.open_file, id) - # TODO: We should consider offloading this to a separate thread - self.bytes_written += sizeof(part) - await offload(self.dump, part, file) + async def flush(self): + while self.shards: + await asyncio.sleep(0.05) + + await asyncio.gather(*self._futures) + assert not self.total_size + self._done = True def close(self): shutil.rmtree(self.directory) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index ee81a277e57..dad871202f1 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -88,7 +88,7 @@ def __init__( directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), memory_limit="200 MiB", # TODO: lift this up to the global ShuffleExtension file_cache=file_cache, - n_files=min(256, metadata.npartitions), + concurrent_files=1, join=pa.concat_tables, # pd.concat ) self.multi_comm = MultiComm( @@ -99,6 +99,7 @@ def __init__( join=functools.partial(sum, start=[]), ) self.worker.loop.add_callback(self.multi_comm.communicate) + self.worker.loop.add_callback(self.multi_file.communicate) self.output_partitions_left = metadata.npartitions_for(worker.address) self.transferred = False @@ -106,6 +107,9 @@ def __init__( async def receive(self, data: list[pa.Buffer]) -> None: assert not self.transferred, "`receive` called after barrier task" + from dask.utils import format_bytes + + print("recved", format_bytes(sum(map(len, data)))) self.total_recvd += sum(map(len, data)) # An ugly way of turning these batches back into an arrow table import io @@ -123,8 +127,7 @@ async def receive(self, data: list[pa.Buffer]) -> None: groups = await offload(split_by_partition, data, self.metadata.column) assert len(data) == sum(map(len, groups.values())) - for output_partition, shard in groups.items(): - await self.multi_file.write(shard, output_partition) + await self.multi_file.put(groups) def add_partition(self, data: pd.DataFrame) -> None: out = split_by_worker( @@ -218,6 +221,7 @@ async def shuffle_inputs_done(self, comm: object, shuffle_id: ShuffleId) -> None shuffle = self._get_shuffle(shuffle_id) await shuffle.multi_comm.flush() await asyncio.sleep(1) # TODO + await shuffle.multi_file.flush() shuffle.inputs_done() if shuffle.done(): # If the shuffle has no output partitions, remove it now; diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 7f2dee71a0e..cf489abcf7d 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -7,6 +7,12 @@ @gen_cluster(client=True) async def test_basic(c, s, a, b): df = dask.datasets.timeseries() + # df = dask.datasets.timeseries( + # start="2000-01-01", + # end="2000-06-02", + # freq="100ms", + # dtypes={"x": int, "y": float, "a": int, "b": float}, + # ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") a, b = c.compute([df.x.size, out.x.size]) a = await a From b99329a091a2853a9bac412c0bf359c7747b2178 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 15 Mar 2022 17:14:06 -0500 Subject: [PATCH 08/81] Pass around only bytes This helps to reduce lots of extra unmanaged memory This flow pretty well right now. I'm finding that it's useful to blend between the disk and comm buffer sizes. The abstraction in multi_file and multi_comm are getting a little bit worn down (it would be awkward to shift back to pandas), but maybe that's ok. --- distributed/shuffle/multi_file.py | 20 +++++++++++++------- distributed/shuffle/shuffle_extension.py | 22 +++++++++++++++++++--- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 86f4c85eca6..4baf3de2c37 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -52,7 +52,7 @@ async def put(self, data: dict): this_size = 0 for id, shard in data.items(): size = self.sizeof(shard) - self.shards[id].append(shard) + self.shards[id].extend(shard) self.sizes[id] += size self.total_size += size self.total_received += size @@ -93,6 +93,9 @@ async def communicate(self): id = max(self.sizes, key=self.sizes.get) shards = self.shards.pop(id) size = self.sizes.pop(id) + from dask.utils import format_bytes + + print("Writing", format_bytes(size), "to disk") future = asyncio.ensure_future(self.process(id, shards, size)) del shards @@ -122,12 +125,15 @@ def _(): def read(self, id): parts = [] - with open(self.directory / str(id), mode="rb", buffering=100_000_000) as f: - while True: - try: - parts.append(self.load(f)) - except EOFError: - break + try: + with open(self.directory / str(id), mode="rb", buffering=100_000_000) as f: + while True: + try: + parts.append(self.load(f)) + except EOFError: + break + except FileNotFoundError: + raise KeyError(id) # TODO: We could consider deleting the file at this point if parts: diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index dad871202f1..b057b3e98a2 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -83,16 +83,19 @@ def __init__( import pyarrow as pa self.multi_file = MultiFile( - dump=dump_arrow, + dump=functools.partial( + dump_batch, schema=pa.Schema.from_pandas(self.metadata.empty) + ), load=load_arrow, directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), - memory_limit="200 MiB", # TODO: lift this up to the global ShuffleExtension + memory_limit="900 MiB", # TODO: lift this up to the global ShuffleExtension file_cache=file_cache, concurrent_files=1, join=pa.concat_tables, # pd.concat + sizeof=lambda L: sum(map(len, L)), ) self.multi_comm = MultiComm( - memory_limit="200 MiB", # TODO + memory_limit="50 MiB", # TODO rpc=worker.rpc, shuffle_id=self.metadata.id, sizeof=lambda L: sum(map(len, L)), @@ -127,6 +130,13 @@ async def receive(self, data: list[pa.Buffer]) -> None: groups = await offload(split_by_partition, data, self.metadata.column) assert len(data) == sum(map(len, groups.values())) + + groups = await offload( + lambda: { + k: [batch.serialize() for batch in v.to_batches()] + for k, v in groups.items() + } + ) # TODO: consider offloading await self.multi_file.put(groups) def add_partition(self, data: pd.DataFrame) -> None: @@ -418,6 +428,12 @@ def split_by_partition( return dict(zip(partitions, shards)) +def dump_batch(batch, file, schema=None): + if file.tell() == 0: + file.write(schema.serialize()) + file.write(batch) + + def dump_arrow(t: pa.Table, file): if file.tell() == 0: file.write(t.schema.serialize()) From 5781b7eaabeff6b8ada706bc561f48ef7d474343 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 16 Mar 2022 06:34:26 -0500 Subject: [PATCH 09/81] Clean up a few extra copies --- distributed/shuffle/multi_comm.py | 12 +++++-- distributed/shuffle/multi_file.py | 4 +-- distributed/shuffle/shuffle_extension.py | 40 +++++++++++++++-------- distributed/shuffle/tests/test_shuffle.py | 12 +++---- 4 files changed, 44 insertions(+), 24 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index 10b27911a5a..994eee90d1a 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -48,11 +48,19 @@ def put(self, data: dict): self.total_size += size self.total_moved += size - del data + del data, shard + + from dask.utils import format_bytes while self.total_size > self.memory_limit: with self.thread_condition: - self.thread_condition.wait(0.100) # Block until memory calms down + print( + "waiting comm", + format_bytes(self.total_size), + "this", + format_bytes(size), + ) + self.thread_condition.wait(1) # Block until memory calms down async def communicate(self): self.comm_queue = asyncio.Queue(maxsize=self.max_connections) diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 4baf3de2c37..21118ab2e0b 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -58,14 +58,14 @@ async def put(self, data: dict): self.total_received += size this_size += size - del data + del data, shard while self.total_size > self.memory_limit: async with self.condition: from dask.utils import format_bytes print( - "waiting", + "waiting disk", format_bytes(self.total_size), "this", format_bytes(this_size), diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index b057b3e98a2..904dcd7631a 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -88,18 +88,19 @@ def __init__( ), load=load_arrow, directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), - memory_limit="900 MiB", # TODO: lift this up to the global ShuffleExtension + memory_limit="500 MiB", # TODO: lift this up to the global ShuffleExtension file_cache=file_cache, - concurrent_files=1, + concurrent_files=3, join=pa.concat_tables, # pd.concat sizeof=lambda L: sum(map(len, L)), ) self.multi_comm = MultiComm( - memory_limit="50 MiB", # TODO + memory_limit="300 MiB", # TODO rpc=worker.rpc, shuffle_id=self.metadata.id, sizeof=lambda L: sum(map(len, L)), join=functools.partial(sum, start=[]), + max_connections=10, ) self.worker.loop.add_callback(self.multi_comm.communicate) self.worker.loop.add_callback(self.multi_file.communicate) @@ -110,26 +111,21 @@ def __init__( async def receive(self, data: list[pa.Buffer]) -> None: assert not self.transferred, "`receive` called after barrier task" + import pyarrow as pa + from dask.utils import format_bytes print("recved", format_bytes(sum(map(len, data)))) self.total_recvd += sum(map(len, data)) # An ugly way of turning these batches back into an arrow table - import io - - import pyarrow as pa - - bio = io.BytesIO() - bio.write(pa.Schema.from_pandas(self.metadata.empty).serialize()) - for batch in data: - bio.write(batch) - bio.seek(0) - sr = pa.RecordBatchStreamReader(bio) - data = sr.read_all() + data = list_of_buffers_to_table( + data, schema=pa.Schema.from_pandas(self.metadata.empty) + ) groups = await offload(split_by_partition, data, self.metadata.column) assert len(data) == sum(map(len, groups.values())) + del data groups = await offload( lambda: { @@ -455,3 +451,19 @@ def worker_for(output_partition: int, workers: list[str], npartitions: int) -> s "Get the address of the worker which should hold this output partition number" i = len(workers) * output_partition // npartitions return workers[i] + + +def list_of_buffers_to_table(data: list[pa.Buffer], schema: pa.Schema) -> pa.Table: + import io + + import pyarrow as pa + + bio = io.BytesIO() + bio.write(schema.serialize()) + for batch in data: + bio.write(batch) + bio.seek(0) + sr = pa.RecordBatchStreamReader(bio) + data = sr.read_all() + bio.close() + return data diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index cf489abcf7d..1f537f77e90 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -6,13 +6,13 @@ @gen_cluster(client=True) async def test_basic(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-06-02", + freq="100ms", + dtypes={"x": int, "y": float, "a": int, "b": float}, + ) df = dask.datasets.timeseries() - # df = dask.datasets.timeseries( - # start="2000-01-01", - # end="2000-06-02", - # freq="100ms", - # dtypes={"x": int, "y": float, "a": int, "b": float}, - # ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") a, b = c.compute([df.x.size, out.x.size]) a = await a From 0ce6e01e1ddc3553b6ac3834ba9d07e0dd610830 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 16 Mar 2022 06:48:08 -0500 Subject: [PATCH 10/81] Let comms continue without blocking on disk Isn't solid yet --- distributed/shuffle/multi_file.py | 8 +++++++- distributed/shuffle/shuffle_extension.py | 16 ++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 21118ab2e0b..4680d7f0a4d 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -95,7 +95,13 @@ async def communicate(self): size = self.sizes.pop(id) from dask.utils import format_bytes - print("Writing", format_bytes(size), "to disk") + print( + "Writing", + format_bytes(size), + "to disk", + format_bytes(self.total_size), + "left", + ) future = asyncio.ensure_future(self.process(id, shards, size)) del shards diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 904dcd7631a..10c7f70ed6b 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -211,13 +211,25 @@ async def shuffle_receive( self, comm: object, shuffle_id: ShuffleId, - data: pd.DataFrame, + data: list[pa.Buffer], ) -> None: """ Hander: Receive an incoming shard of data from a peer worker. Using an unknown ``shuffle_id`` is an error. """ - await self._get_shuffle(shuffle_id).receive(data) + shuffle = self._get_shuffle(shuffle_id) + await shuffle.receive(data) + return + # TODO: it would be good to not have comms wait on disk if not + # necessary + future = asyncio.ensure_future(shuffle.receive(data)) + await future # backpressure + if ( + shuffle.multi_file.total_size + sum(map(len, data)) + > shuffle.multi_file.memory_limit + ): + return + await future # backpressure async def shuffle_inputs_done(self, comm: object, shuffle_id: ShuffleId) -> None: """ From 3204997a0c3f0da1d3c4c25758790afec27f1ee9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 16 Mar 2022 07:07:38 -0500 Subject: [PATCH 11/81] Move flush into multi_file.read This avoids a race --- distributed/shuffle/multi_comm.py | 3 +++ distributed/shuffle/multi_file.py | 4 ++++ distributed/shuffle/shuffle_extension.py | 3 +-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index 994eee90d1a..f7675044527 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -119,8 +119,11 @@ async def flush(self): await asyncio.sleep(0.05) await asyncio.gather(*self._futures) + self._futures.clear() + assert not self.total_size from dask.utils import format_bytes print("total moved", format_bytes(self.total_moved)) + self._done = True diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 4680d7f0a4d..17373dcf5fb 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -154,7 +154,11 @@ async def flush(self): await asyncio.sleep(0.05) await asyncio.gather(*self._futures) + if all(future.done() for future in self._futures): + self._futures.clear() + assert not self.total_size + self._done = True def close(self): diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 10c7f70ed6b..fc4de9cee51 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -161,6 +161,7 @@ def get_output_partition(self, i: int) -> pd.DataFrame: ), f"No outputs remaining, but requested output partition {i} on {self.worker.address}." self.output_partitions_left -= 1 + sync(self.worker.loop, self.multi_file.flush) # type: ignore try: df = self.multi_file.read(i) return df.to_pandas() @@ -238,8 +239,6 @@ async def shuffle_inputs_done(self, comm: object, shuffle_id: ShuffleId) -> None """ shuffle = self._get_shuffle(shuffle_id) await shuffle.multi_comm.flush() - await asyncio.sleep(1) # TODO - await shuffle.multi_file.flush() shuffle.inputs_done() if shuffle.done(): # If the shuffle has no output partitions, remove it now; From 8b11d6dbccae87fb59aade3a57c52b5a36ba7c4e Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 16 Mar 2022 07:09:47 -0500 Subject: [PATCH 12/81] Avoid multiple accesses to the same file --- distributed/shuffle/multi_file.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 17373dcf5fb..539edeeafde 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -47,6 +47,7 @@ def __init__( self._done = False self._futures = set() + self.active = set() async def put(self, data: dict): this_size = 0 @@ -112,6 +113,10 @@ async def communicate(self): async def process(self, id: str, shards: list, size: int): with log_errors(): # Consider boosting total_size a bit here to account for duplication + while id in self.active: + await asyncio.sleep(0.01) + + self.active.add(id) def _(): # TODO: offload @@ -123,6 +128,7 @@ def _(): await offload(_) + self.active.remove(id) self.total_size -= size async with self.condition: self.condition.notify() From 62cc43d0120aae116b8423380a3b365edf96da28 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 16 Mar 2022 07:44:53 -0500 Subject: [PATCH 13/81] Change configuration for smoother single-machine use We don't need a lot of comm buffer, we also don't want more connecitons than machines (too much sitting in buffers). We also improve some printing --- distributed/shuffle/multi_comm.py | 25 ++++++++++++++++++------ distributed/shuffle/multi_file.py | 22 +++++++++++++-------- distributed/shuffle/shuffle_extension.py | 15 +++++++------- 3 files changed, 41 insertions(+), 21 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index f7675044527..efdfbc608cb 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -52,14 +52,16 @@ def put(self, data: dict): from dask.utils import format_bytes + if self.total_size > self.memory_limit: + print( + "waiting comm", + format_bytes(self.total_size), + "this", + format_bytes(size), + ) + while self.total_size > self.memory_limit: with self.thread_condition: - print( - "waiting comm", - format_bytes(self.total_size), - "this", - format_bytes(size), - ) self.thread_condition.wait(1) # Block until memory calms down async def communicate(self): @@ -67,6 +69,8 @@ async def communicate(self): for _ in range(self.max_connections): self.comm_queue.put_nowait(None) + from dask.utils import format_bytes + while not self._done: if not self.shards: await asyncio.sleep(0.1) @@ -92,6 +96,15 @@ async def communicate(self): size += s self.sizes[address] -= s + print( + "Sending", + format_bytes(size), + "to comm", + format_bytes(self.total_size), + "left in ", + len(self.shards), + "buckets", + ) future = asyncio.ensure_future(self.process(address, shards, size)) del shards self._futures.add(future) diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 539edeeafde..3d28c9e3165 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -61,16 +61,19 @@ async def put(self, data: dict): del data, shard + from dask.utils import format_bytes + + if self.total_size > self.memory_limit: + print( + "waiting disk", + format_bytes(self.total_size), + "this", + format_bytes(this_size), + ) + while self.total_size > self.memory_limit: async with self.condition: - from dask.utils import format_bytes - print( - "waiting disk", - format_bytes(self.total_size), - "this", - format_bytes(this_size), - ) try: await asyncio.wait_for( self.condition.wait(), 1 @@ -101,7 +104,9 @@ async def communicate(self): format_bytes(size), "to disk", format_bytes(self.total_size), - "left", + "left in", + len(self.shards), + "buckets", ) future = asyncio.ensure_future(self.process(id, shards, size)) @@ -125,6 +130,7 @@ def _(): ) as f: for shard in shards: self.dump(shard, f) + os.fsync(f) # TODO: maybe? await offload(_) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index fc4de9cee51..24290e16253 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -88,19 +88,19 @@ def __init__( ), load=load_arrow, directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), - memory_limit="500 MiB", # TODO: lift this up to the global ShuffleExtension + memory_limit="900 MiB", # TODO: lift this up to the global ShuffleExtension file_cache=file_cache, concurrent_files=3, join=pa.concat_tables, # pd.concat sizeof=lambda L: sum(map(len, L)), ) self.multi_comm = MultiComm( - memory_limit="300 MiB", # TODO + memory_limit="50 MiB", # TODO rpc=worker.rpc, shuffle_id=self.metadata.id, sizeof=lambda L: sum(map(len, L)), join=functools.partial(sum, start=[]), - max_connections=10, + max_connections=3, ) self.worker.loop.add_callback(self.multi_comm.communicate) self.worker.loop.add_callback(self.multi_file.communicate) @@ -118,8 +118,10 @@ async def receive(self, data: list[pa.Buffer]) -> None: print("recved", format_bytes(sum(map(len, data)))) self.total_recvd += sum(map(len, data)) # An ugly way of turning these batches back into an arrow table - data = list_of_buffers_to_table( - data, schema=pa.Schema.from_pandas(self.metadata.empty) + data = await offload( + list_of_buffers_to_table, + data, + schema=pa.Schema.from_pandas(self.metadata.empty), ) groups = await offload(split_by_partition, data, self.metadata.column) @@ -224,12 +226,11 @@ async def shuffle_receive( # TODO: it would be good to not have comms wait on disk if not # necessary future = asyncio.ensure_future(shuffle.receive(data)) - await future # backpressure + # await future # backpressure if ( shuffle.multi_file.total_size + sum(map(len, data)) > shuffle.multi_file.memory_limit ): - return await future # backpressure async def shuffle_inputs_done(self, comm: object, shuffle_id: ShuffleId) -> None: From 5a642485bf57042e457eda30a4b8a6df2cc3a4d3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 16 Mar 2022 12:31:16 -0500 Subject: [PATCH 14/81] Fix up some concurrency issues --- distributed/shuffle/multi_comm.py | 40 +++++++++++++---------- distributed/shuffle/shuffle_extension.py | 9 +++-- distributed/shuffle/tests/test_shuffle.py | 4 +-- 3 files changed, 32 insertions(+), 21 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index efdfbc608cb..db323d39228 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -20,7 +20,7 @@ def __init__( sizeof=sizeof, max_connections=10, shuffle_id=None, - max_message_size="128 MiB", + max_message_size="10 MiB", ): self.lock = threading.Lock() self.shards = defaultdict(list) @@ -87,27 +87,31 @@ async def communicate(self): try: shard = self.shards[address].pop() except IndexError: - del self.shards[address] - del self.sizes[address] break - else: + finally: shards.append(shard) s = self.sizeof(shard) size += s self.sizes[address] -= s - - print( - "Sending", - format_bytes(size), - "to comm", - format_bytes(self.total_size), - "left in ", - len(self.shards), - "buckets", - ) - future = asyncio.ensure_future(self.process(address, shards, size)) - del shards - self._futures.add(future) + if not self.shards[address]: + del self.shards[address] + assert not self.sizes[address] + del self.sizes[address] + + assert set(self.sizes) == set(self.shards) + assert shards + print( + "Sending", + format_bytes(size), + "to comm", + format_bytes(self.total_size), + "left in ", + len(self.shards), + "buckets", + ) + future = asyncio.ensure_future(self.process(address, shards, size)) + del shards + self._futures.add(future) async def process(self, address: str, shards: list, size: int): with log_errors(): @@ -134,6 +138,8 @@ async def flush(self): await asyncio.gather(*self._futures) self._futures.clear() + if self.total_size: + breakpoint() assert not self.total_size from dask.utils import format_bytes diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 24290e16253..f5ec8ee4a2c 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -90,7 +90,7 @@ def __init__( directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), memory_limit="900 MiB", # TODO: lift this up to the global ShuffleExtension file_cache=file_cache, - concurrent_files=3, + concurrent_files=2, join=pa.concat_tables, # pd.concat sizeof=lambda L: sum(map(len, L)), ) @@ -110,7 +110,10 @@ def __init__( self.total_recvd = 0 async def receive(self, data: list[pa.Buffer]) -> None: - assert not self.transferred, "`receive` called after barrier task" + # This is actually ok. Our local barrier might have finished, + # but barriers on other workers might still be running and sending us + # data + # assert not self.transferred, "`receive` called after barrier task" import pyarrow as pa from dask.utils import format_bytes @@ -432,6 +435,8 @@ def split_by_partition( ] shards.append(t.slice(offset=splits[-1], length=None)) assert len(t) == sum(map(len, shards)) + if len(partitions) != len(shards): + breakpoint() assert len(partitions) == len(shards) return dict(zip(partitions, shards)) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 1f537f77e90..e7f24b3de6d 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -4,7 +4,7 @@ from distributed.utils_test import gen_cluster -@gen_cluster(client=True) +@gen_cluster(client=True, timeout=1000000) async def test_basic(c, s, a, b): df = dask.datasets.timeseries( start="2000-01-01", @@ -12,7 +12,7 @@ async def test_basic(c, s, a, b): freq="100ms", dtypes={"x": int, "y": float, "a": int, "b": float}, ) - df = dask.datasets.timeseries() + # df = dask.datasets.timeseries() out = dd.shuffle.shuffle(df, "x", shuffle="p2p") a, b = c.compute([df.x.size, out.x.size]) a = await a From b90161353e0581348f8c0fd5343a497c86e564f2 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 16 Mar 2022 16:20:04 -0500 Subject: [PATCH 15/81] Fix shard size accountiing --- distributed/shuffle/multi_comm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index db323d39228..a4d8281cbfc 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -86,13 +86,13 @@ async def communicate(self): while size < self.max_message_size: try: shard = self.shards[address].pop() - except IndexError: - break - finally: shards.append(shard) s = self.sizeof(shard) size += s self.sizes[address] -= s + except IndexError: + break + finally: if not self.shards[address]: del self.shards[address] assert not self.sizes[address] From 24092bb99c9182c3e983d499d188153f76ed4a81 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 16 Mar 2022 16:22:02 -0500 Subject: [PATCH 16/81] add more connections if more workers --- distributed/shuffle/shuffle_extension.py | 4 ++-- distributed/shuffle/tests/test_shuffle.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index f5ec8ee4a2c..a3c93b6c2a8 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -95,12 +95,12 @@ def __init__( sizeof=lambda L: sum(map(len, L)), ) self.multi_comm = MultiComm( - memory_limit="50 MiB", # TODO + memory_limit="200 MiB", # TODO rpc=worker.rpc, shuffle_id=self.metadata.id, sizeof=lambda L: sum(map(len, L)), join=functools.partial(sum, start=[]), - max_connections=3, + max_connections=min((len(self.metadata.workers) - 1) or 1, 10), ) self.worker.loop.add_callback(self.multi_comm.communicate) self.worker.loop.add_callback(self.multi_file.communicate) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index e7f24b3de6d..ee03620f96d 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -12,7 +12,7 @@ async def test_basic(c, s, a, b): freq="100ms", dtypes={"x": int, "y": float, "a": int, "b": float}, ) - # df = dask.datasets.timeseries() + df = dask.datasets.timeseries() out = dd.shuffle.shuffle(df, "x", shuffle="p2p") a, b = c.compute([df.x.size, out.x.size]) a = await a From 7fbe4aa77c0d0b8a3ab3e8d5a180fe443840a8b2 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 17 Mar 2022 12:01:07 -0500 Subject: [PATCH 17/81] Allow worker extensions to piggy-back on heartbeat To enable better diagnostics, it would be useful to allow worker extensions to piggy-back on the standard heartbeat. This adds an optional "heartbeat" method to extensions, and, if present, calls a custom method that gets sent to the scheduler and processed by an extension of the same name. This also starts to store the extensions on the worker in a named dictionary. Previously this was a list, but I'm not sure that it was actually used anywhere. This is a breaking change without deprecation, but in a space that I suspect no one will care about. I'm happy to provide a fallback if desired. --- distributed/scheduler.py | 5 +++++ distributed/tests/test_worker.py | 28 ++++++++++++++++++++++++++++ distributed/worker.py | 17 +++++++++++++---- 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 2c2ae8fa4aa..33e9051b1c0 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4328,6 +4328,7 @@ def heartbeat_worker( host_info: dict = None, metrics: dict, executing: dict = None, + extensions: dict = None, ): parent: SchedulerState = cast(SchedulerState, self) address = self.coerce_address(address, resolve_address) @@ -4415,6 +4416,10 @@ def heartbeat_worker( if resources: self.add_resources(worker=address, resources=resources) + if extensions: + for name, data in extensions.items(): + self.extensions[name].heartbeat(data) + self.log_event(address, merge({"action": "heartbeat"}, metrics)) return { diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 0aefe05bfbc..20f0e2466f5 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3759,3 +3759,31 @@ async def close(): await w.close(executor_wait=True) await asyncio.gather(block(), close(), set_future()) + + +@gen_cluster(nthreads=[]) +async def test_extensions(s): + flag = [False] + + class WorkerExtension: + def __init__(self, worker): + pass + + def heartbeat(self): + return {"data": 123} + + class SchedulerExtension: + def __init__(self, scheduler): + self.scheduler = scheduler + pass + + def heartbeat(self, data: dict): + assert data == {"data": 123} + flag[0] = True + + s.extensions["test"] = SchedulerExtension(s) + + async with Worker(s.address, extensions={"test": WorkerExtension}) as w: + await w.heartbeat() + + assert flag[0] diff --git a/distributed/worker.py b/distributed/worker.py index 918f01d2e5c..a992fa14d54 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -136,7 +136,10 @@ no_value = "--no-value-sentinel--" -DEFAULT_EXTENSIONS: list[type] = [PubSubWorkerExtension, ShuffleWorkerExtension] +DEFAULT_EXTENSIONS: dict[str, type] = { + "pubsub": PubSubWorkerExtension, + "shuffle": ShuffleWorkerExtension, +} DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {} @@ -458,7 +461,7 @@ def __init__( memory_spill_fraction: float | Literal[False] | None = None, memory_pause_fraction: float | Literal[False] | None = None, max_spill: float | str | Literal[False] | None = None, - extensions: list[type] | None = None, + extensions: dict[str, type] | None = None, metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS, startup_information: Mapping[ str, Callable[[Worker], Any] @@ -864,8 +867,9 @@ def __init__( if extensions is None: extensions = DEFAULT_EXTENSIONS - for ext in extensions: - ext(self) + self.extensions = { + name: extension(self) for name, extension in extensions.items() + } self._throttled_gc = ThrottledGC(logger=logger) @@ -1178,6 +1182,11 @@ async def heartbeat(self): for key in self.active_keys if key in self.tasks }, + extensions={ + name: extension.heartbeat() + for name, extension in self.extensions.items() + if hasattr(extension, "heartbeat") + }, ) end = time() middle = (start + end) / 2 From bc81db8972185bf9b5c1688149e8a930da1e5bc2 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 17 Mar 2022 12:07:10 -0500 Subject: [PATCH 18/81] Remove file cache --- distributed/shuffle/multi_file.py | 6 ++---- distributed/shuffle/shuffle_extension.py | 10 +++++----- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 3d28c9e3165..a7e339632cf 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -8,8 +8,8 @@ from dask.sizeof import sizeof from dask.utils import parse_bytes -from ..system import MEMORY_LIMIT -from ..utils import log_errors, offload +from distributed.system import MEMORY_LIMIT +from distributed.utils import log_errors, offload class MultiFile: @@ -21,7 +21,6 @@ def __init__( join=None, concurrent_files=1, memory_limit=MEMORY_LIMIT / 2, - file_cache=None, sizeof=sizeof, ): assert join @@ -175,7 +174,6 @@ async def flush(self): def close(self): shutil.rmtree(self.directory) - self.file_cache.clear() def __enter__(self): return self diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index a3c93b6c2a8..b74d1dd6737 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, NewType import toolz -import zict from distributed.protocol import to_serialize from distributed.shuffle.multi_comm import MultiComm @@ -75,7 +74,9 @@ class Shuffle: "State for a single active shuffle" def __init__( - self, metadata: ShuffleMetadata, worker: Worker, file_cache=None + self, + metadata: ShuffleMetadata, + worker: Worker, ) -> None: self.metadata = metadata self.worker = worker @@ -89,7 +90,6 @@ def __init__( load=load_arrow, directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), memory_limit="900 MiB", # TODO: lift this up to the global ShuffleExtension - file_cache=file_cache, concurrent_files=2, join=pa.concat_tables, # pd.concat sizeof=lambda L: sum(map(len, L)), @@ -194,7 +194,6 @@ def __init__(self, worker: Worker) -> None: # Initialize self.worker: Worker = worker self.shuffles: dict[ShuffleId, Shuffle] = {} - self.file_cache = zict.LRU(256, dict(), on_evict=lambda k, v: v.close()) # Handlers ########## @@ -210,7 +209,8 @@ def shuffle_init(self, comm: object, metadata: ShuffleMetadata) -> None: f"Shuffle {metadata.id!r} is already registered on worker {self.worker.address}" ) self.shuffles[metadata.id] = Shuffle( - metadata, self.worker, file_cache=self.file_cache + metadata, + self.worker, ) async def shuffle_receive( From 27d2ab366b50aeb6f90e71691965ca0d48680405 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 17 Mar 2022 12:24:21 -0500 Subject: [PATCH 19/81] First pass on adding a Scheduler extension and worker heartbeat --- distributed/scheduler.py | 4 ++- distributed/shuffle/__init__.py | 8 +----- distributed/shuffle/shuffle_extension.py | 32 +++++++++++++++++++++++ distributed/shuffle/tests/test_shuffle.py | 14 ++++++++++ 4 files changed, 50 insertions(+), 8 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 33e9051b1c0..9b2be1e4b14 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -80,6 +80,7 @@ from distributed.recreate_tasks import ReplayTaskScheduler from distributed.security import Security from distributed.semaphore import SemaphoreExtension +from distributed.shuffle import ShuffleSchedulerExtension from distributed.stealing import WorkStealing from distributed.utils import ( All, @@ -183,6 +184,7 @@ def nogil(func): EventExtension, ActiveMemoryManagerExtension, MemorySamplerExtension, + ShuffleSchedulerExtension, ] ALL_TASK_STATES = declare( @@ -4418,7 +4420,7 @@ def heartbeat_worker( if extensions: for name, data in extensions.items(): - self.extensions[name].heartbeat(data) + self.extensions[name].heartbeat(ws, data) self.log_event(address, merge({"action": "heartbeat"}, metrics)) diff --git a/distributed/shuffle/__init__.py b/distributed/shuffle/__init__.py index 29d5610d373..7a47cba5040 100644 --- a/distributed/shuffle/__init__.py +++ b/distributed/shuffle/__init__.py @@ -2,12 +2,6 @@ from distributed.shuffle.shuffle_extension import ( ShuffleId, ShuffleMetadata, + ShuffleSchedulerExtension, ShuffleWorkerExtension, ) - -__all__ = [ - "rearrange_by_column_p2p", - "ShuffleId", - "ShuffleMetadata", - "ShuffleWorkerExtension", -] diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index b74d1dd6737..f26eb6d2b58 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -4,6 +4,7 @@ import functools import math import os +from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, NewType @@ -109,6 +110,24 @@ def __init__( self.transferred = False self.total_recvd = 0 + def heartbeat(self): + return { + "disk": { + "memory": self.multi_file.total_size, + "buckets": len(self.multi_file.shards), + "written": self.multi_file.bytes_written, + "read": self.multi_file.bytes_read, + "active": len(self.multi_file.active), + }, + "comms": { + "memory": self.multi_comm.total_size, + "buckets": len(self.multi_comm.shards), + "written": self.multi_comm.total_moved, + "read": self.total_recvd, + "active": self.multi_comm.comm_queue.qsize(), # TODO: maybe not built yet + }, + } + async def receive(self, data: list[pa.Buffer]) -> None: # This is actually ok. Our local barrier might have finished, # but barriers on other workers might still be running and sending us @@ -213,6 +232,9 @@ def shuffle_init(self, comm: object, metadata: ShuffleMetadata) -> None: self.worker, ) + def heartbeat(self): + return {id: shuffle.heartbeat() for id, shuffle in self.shuffles.items()} + async def shuffle_receive( self, comm: object, @@ -381,6 +403,16 @@ def _get_shuffle(self, shuffle_id: ShuffleId) -> Shuffle: ) from None +class ShuffleSchedulerExtension: + def __init__(self, scheduler): + self.scheduler = scheduler + self.shuffles = defaultdict(lambda defaultdict: dict) + + def heartbeat(self, ws, data): + for shuffle_id, d in data.items(): + self.shuffles[shuffle_id][ws.address].update(d) + + def split_by_worker( df: pd.DataFrame, column: str, npartitions: int, workers: list[str] ) -> dict: diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index ee03620f96d..a6fbae6a37b 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -18,3 +18,17 @@ async def test_basic(c, s, a, b): a = await a b = await b assert a == b + + +@gen_cluster(client=True) +async def test_heartbeat(c, s, a, b): + await a.heartbeat() + assert not s.extensions["shuffle"].shuffles + df = dask.datasets.timeseries( + dtypes={"x": float, "y": float}, + ) + df = dask.datasets.timeseries() + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + await out.persist() + + assert s.extensions["shuffle"].shuffles From 9fd6da0e503b5b3e1d2300bbde73560459fe53f7 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 17 Mar 2022 12:41:15 -0500 Subject: [PATCH 20/81] Name scheduler extensions --- distributed/event.py | 2 -- distributed/lock.py | 2 -- distributed/multi_lock.py | 2 -- distributed/publish.py | 1 - distributed/pubsub.py | 2 -- distributed/queues.py | 2 -- distributed/recreate_tasks.py | 1 - distributed/scheduler.py | 37 ++++++++++++++++---------------- distributed/semaphore.py | 2 -- distributed/stealing.py | 1 - distributed/tests/test_worker.py | 5 +++-- distributed/variable.py | 2 -- 12 files changed, 22 insertions(+), 37 deletions(-) diff --git a/distributed/event.py b/distributed/event.py index 0765e003158..037d171a030 100644 --- a/distributed/event.py +++ b/distributed/event.py @@ -58,8 +58,6 @@ def __init__(self, scheduler): } ) - self.scheduler.extensions["events"] = self - async def event_wait(self, name=None, timeout=None): """Wait until the event is set to true. Returns false, when this did not happen in the given time diff --git a/distributed/lock.py b/distributed/lock.py index 5830e2de94b..22e3de5e223 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -30,8 +30,6 @@ def __init__(self, scheduler): {"lock_acquire": self.acquire, "lock_release": self.release} ) - self.scheduler.extensions["locks"] = self - async def acquire(self, name=None, id=None, timeout=None): with log_errors(): if isinstance(name, list): diff --git a/distributed/multi_lock.py b/distributed/multi_lock.py index 31b2e6ebbdb..7907f44ecfc 100644 --- a/distributed/multi_lock.py +++ b/distributed/multi_lock.py @@ -46,8 +46,6 @@ def __init__(self, scheduler): {"multi_lock_acquire": self.acquire, "multi_lock_release": self.release} ) - self.scheduler.extensions["multi_locks"] = self - def _request_locks(self, locks: list[str], id: Hashable, num_locks: int) -> bool: """Request locks diff --git a/distributed/publish.py b/distributed/publish.py index 63772519376..161b025bbc0 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -26,7 +26,6 @@ def __init__(self, scheduler): } self.scheduler.handlers.update(handlers) - self.scheduler.extensions["publish"] = self def put(self, keys=None, data=None, name=None, override=False, client=None): with log_errors(): diff --git a/distributed/pubsub.py b/distributed/pubsub.py index f1cbc62e531..1bfcb490af3 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -34,8 +34,6 @@ def __init__(self, scheduler): } ) - self.scheduler.extensions["pubsub"] = self - def add_publisher(self, name=None, worker=None): logger.debug("Add publisher: %s %s", name, worker) self.publishers[name].add(worker) diff --git a/distributed/queues.py b/distributed/queues.py index c29c4f1ab2c..3dc563b3a52 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -42,8 +42,6 @@ def __init__(self, scheduler): {"queue-future-release": self.future_release, "queue_release": self.release} ) - self.scheduler.extensions["queues"] = self - def create(self, name=None, client=None, maxsize=0): logger.debug(f"Queue name: {name}") if name not in self.queues: diff --git a/distributed/recreate_tasks.py b/distributed/recreate_tasks.py index 82b72092b43..8bf2d74912d 100644 --- a/distributed/recreate_tasks.py +++ b/distributed/recreate_tasks.py @@ -23,7 +23,6 @@ def __init__(self, scheduler): self.scheduler = scheduler self.scheduler.handlers["get_runspec"] = self.get_runspec self.scheduler.handlers["get_error_cause"] = self.get_error_cause - self.scheduler.extensions["replay-tasks"] = self def _process_key(self, key): if isinstance(key, list): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 33e9051b1c0..2eb62b9de71 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -171,19 +171,19 @@ def nogil(func): Py_ssize_t, parse_bytes(dask.config.get("distributed.scheduler.default-data-size")) ) -DEFAULT_EXTENSIONS = [ - LockExtension, - MultiLockExtension, - PublishExtension, - ReplayTaskScheduler, - QueueExtension, - VariableExtension, - PubSubSchedulerExtension, - SemaphoreExtension, - EventExtension, - ActiveMemoryManagerExtension, - MemorySamplerExtension, -] +DEFAULT_EXTENSIONS = { + "locks": LockExtension, + "multi_locks": MultiLockExtension, + "publish": PublishExtension, + "replay-tasks": ReplayTaskScheduler, + "queues": QueueExtension, + "variables": VariableExtension, + "pubsub": PubSubSchedulerExtension, + "semaphores": SemaphoreExtension, + "events": EventExtension, + "amm": ActiveMemoryManagerExtension, + "memory_sampler": MemorySamplerExtension, +} ALL_TASK_STATES = declare( set, {"released", "waiting", "no-worker", "processing", "erred", "memory"} @@ -4011,11 +4011,12 @@ def __init__( self.periodic_callbacks["idle-timeout"] = pc if extensions is None: - extensions = list(DEFAULT_EXTENSIONS) + extensions = DEFAULT_EXTENSIONS.copy() if dask.config.get("distributed.scheduler.work-stealing"): - extensions.append(WorkStealing) - for ext in extensions: - ext(self) + extensions["stealing"] = WorkStealing + self._extensions = { + name: extension(self) for name, extension in extensions.items() + } setproctitle("dask-scheduler [not started]") Scheduler._instances.add(self) @@ -4418,7 +4419,7 @@ def heartbeat_worker( if extensions: for name, data in extensions.items(): - self.extensions[name].heartbeat(data) + self.extensions[name].heartbeat(ws, data) self.log_event(address, merge({"action": "heartbeat"}, metrics)) diff --git a/distributed/semaphore.py b/distributed/semaphore.py index 9e7abd872c0..d288462b706 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -69,8 +69,6 @@ def __init__(self, scheduler): } ) - self.scheduler.extensions["semaphores"] = self - # {metric_name: {semaphore_name: metric}} self.metrics = { "acquire_total": defaultdict(int), # counter diff --git a/distributed/stealing.py b/distributed/stealing.py index 6f957ca2f42..9789cb58b45 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -71,7 +71,6 @@ def __init__(self, scheduler): ) # `callback_time` is in milliseconds self.scheduler.add_plugin(self) - self.scheduler.extensions["stealing"] = self self.scheduler.events["stealing"] = deque(maxlen=100000) self.count = 0 # { task state: } diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 20f0e2466f5..fa21d340ba4 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3762,7 +3762,7 @@ async def close(): @gen_cluster(nthreads=[]) -async def test_extensions(s): +async def test_extension_heartbeat(s): flag = [False] class WorkerExtension: @@ -3777,7 +3777,8 @@ def __init__(self, scheduler): self.scheduler = scheduler pass - def heartbeat(self, data: dict): + def heartbeat(self, ws, data: dict): + assert ws in self.scheduler.workers.values() assert data == {"data": 123} flag[0] = True diff --git a/distributed/variable.py b/distributed/variable.py index a27abc3ab85..143df9e4153 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -40,8 +40,6 @@ def __init__(self, scheduler): self.scheduler.stream_handlers["variable-future-release"] = self.future_release self.scheduler.stream_handlers["variable_delete"] = self.delete - self.scheduler.extensions["variables"] = self - async def set(self, name=None, key=None, data=None, client=None): if key is not None: record = {"type": "Future", "value": key} From 34617dbfc09a20adde55a1828ec891a7bae25570 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 17 Mar 2022 12:44:02 -0500 Subject: [PATCH 21/81] fixup test --- distributed/shuffle/shuffle_extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index f26eb6d2b58..de5e28d1cab 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -406,7 +406,7 @@ def _get_shuffle(self, shuffle_id: ShuffleId) -> Shuffle: class ShuffleSchedulerExtension: def __init__(self, scheduler): self.scheduler = scheduler - self.shuffles = defaultdict(lambda defaultdict: dict) + self.shuffles = defaultdict(lambda: defaultdict(dict)) def heartbeat(self, ws, data): for shuffle_id, d in data.items(): From e1c0a4d1ef78b2b8ebc1afb7575dc4b189b8832f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 17 Mar 2022 13:47:06 -0500 Subject: [PATCH 22/81] Add timing and diagnostics --- distributed/shuffle/multi_comm.py | 49 ++++++++++++----- distributed/shuffle/multi_file.py | 61 ++++++++++++++------- distributed/shuffle/shuffle_extension.py | 64 +++++++++++++++-------- distributed/shuffle/tests/test_shuffle.py | 10 ++-- 4 files changed, 126 insertions(+), 58 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index a4d8281cbfc..001ede98c84 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -1,14 +1,16 @@ import asyncio +import contextlib import threading +import time from collections import defaultdict from dask.utils import parse_bytes -from ..core import rpc -from ..protocol import to_serialize -from ..sizeof import sizeof -from ..system import MEMORY_LIMIT -from ..utils import log_errors +from distributed.core import rpc +from distributed.protocol import to_serialize +from distributed.sizeof import sizeof +from distributed.system import MEMORY_LIMIT +from distributed.utils import log_errors class MultiComm: @@ -38,6 +40,7 @@ def __init__( self._futures = set() self._done = False self.rpc = rpc + self.diagnostics = defaultdict(float) def put(self, data: dict): with self.lock: @@ -61,8 +64,9 @@ def put(self, data: dict): ) while self.total_size > self.memory_limit: - with self.thread_condition: - self.thread_condition.wait(1) # Block until memory calms down + with self.time("waiting-on-memory"): + with self.thread_condition: + self.thread_condition.wait(1) # Block until memory calms down async def communicate(self): self.comm_queue = asyncio.Queue(maxsize=self.max_connections) @@ -72,11 +76,12 @@ async def communicate(self): from dask.utils import format_bytes while not self._done: - if not self.shards: - await asyncio.sleep(0.1) - continue + with self.time("idle"): + if not self.shards: + await asyncio.sleep(0.1) + continue - await self.comm_queue.get() + await self.comm_queue.get() with self.lock: address = max(self.sizes, key=self.sizes.get) @@ -121,10 +126,19 @@ async def process(self, address: str, shards: list, size: int): # Consider boosting total_size a bit here to account for duplication try: - await self.rpc(address).shuffle_receive( - data=to_serialize(shards), - shuffle_id=self.shuffle_id, + start = time.time() + with self.time("send"): + await self.rpc(address).shuffle_receive( + data=to_serialize(shards), + shuffle_id=self.shuffle_id, + ) + stop = time.time() + self.diagnostics["avg_size"] = ( + 0.95 * self.diagnostics["avg_size"] + 0.05 * size ) + self.diagnostics["avg_duration"] = 0.98 * self.diagnostics[ + "avg_duration" + ] + 0.02 * (stop - start) finally: self.total_size -= size with self.thread_condition: @@ -146,3 +160,10 @@ async def flush(self): print("total moved", format_bytes(self.total_moved)) self._done = True + + @contextlib.contextmanager + def time(self, name: str): + start = time.time() + yield + stop = time.time() + self.diagnostics[name] += stop - start diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index a7e339632cf..fcee55eafd6 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -1,8 +1,10 @@ import asyncio +import contextlib import os import pathlib import pickle import shutil +import time from collections import defaultdict from dask.sizeof import sizeof @@ -47,6 +49,7 @@ def __init__( self._done = False self._futures = set() self.active = set() + self.diagnostics = defaultdict(float) async def put(self, data: dict): this_size = 0 @@ -71,14 +74,15 @@ async def put(self, data: dict): ) while self.total_size > self.memory_limit: - async with self.condition: + with self.time("waiting-on-memory"): + async with self.condition: - try: - await asyncio.wait_for( - self.condition.wait(), 1 - ) # Block until memory calms down - except asyncio.TimeoutError: - continue + try: + await asyncio.wait_for( + self.condition.wait(), 1 + ) # Block until memory calms down + except asyncio.TimeoutError: + continue async def communicate(self): with log_errors(): @@ -87,11 +91,12 @@ async def communicate(self): self.queue.put_nowait(None) while not self._done: - if not self.shards: - await asyncio.sleep(0.1) - continue + with self.time("idle"): + if not self.shards: + await asyncio.sleep(0.1) + continue - await self.queue.get() + await self.queue.get() id = max(self.sizes, key=self.sizes.get) shards = self.shards.pop(id) @@ -131,7 +136,17 @@ def _(): self.dump(shard, f) os.fsync(f) # TODO: maybe? - await offload(_) + start = time.time() + with self.time("write"): + await offload(_) + stop = time.time() + + self.diagnostics["avg_size"] = ( + 0.98 * self.diagnostics["avg_size"] + 0.02 * size + ) + self.diagnostics["avg_duration"] = 0.98 * self.diagnostics[ + "avg_duration" + ] + 0.02 * (stop - start) self.active.remove(id) self.total_size -= size @@ -143,12 +158,15 @@ def read(self, id): parts = [] try: - with open(self.directory / str(id), mode="rb", buffering=100_000_000) as f: - while True: - try: - parts.append(self.load(f)) - except EOFError: - break + with self.time("read"): + with open( + self.directory / str(id), mode="rb", buffering=100_000_000 + ) as f: + while True: + try: + parts.append(self.load(f)) + except EOFError: + break except FileNotFoundError: raise KeyError(id) @@ -180,3 +198,10 @@ def __enter__(self): def __exit__(self, exc, typ, traceback): self.close() + + @contextlib.contextmanager + def time(self, name: str): + start = time.time() + yield + stop = time.time() + self.diagnostics[name] += stop - start diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index de5e28d1cab..6c1e53285b3 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -1,9 +1,11 @@ from __future__ import annotations import asyncio +import contextlib import functools import math import os +import time from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, NewType @@ -106,9 +108,18 @@ def __init__( self.worker.loop.add_callback(self.multi_comm.communicate) self.worker.loop.add_callback(self.multi_file.communicate) + self.diagnostics: dict[str, float] = defaultdict(float) self.output_partitions_left = metadata.npartitions_for(worker.address) self.transferred = False self.total_recvd = 0 + self.start_time = time.time() + + @contextlib.contextmanager + def time(self, name: str): + start = time.time() + yield + stop = time.time() + self.diagnostics[name] += stop - start def heartbeat(self): return { @@ -118,6 +129,7 @@ def heartbeat(self): "written": self.multi_file.bytes_written, "read": self.multi_file.bytes_read, "active": len(self.multi_file.active), + "diagnostics": self.multi_file.diagnostics, }, "comms": { "memory": self.multi_comm.total_size, @@ -125,7 +137,10 @@ def heartbeat(self): "written": self.multi_comm.total_moved, "read": self.total_recvd, "active": self.multi_comm.comm_queue.qsize(), # TODO: maybe not built yet + "diagnostics": self.multi_comm.diagnostics, }, + "diagnostics": self.diagnostics, + "start": self.start_time, } async def receive(self, data: list[pa.Buffer]) -> None: @@ -140,34 +155,40 @@ async def receive(self, data: list[pa.Buffer]) -> None: print("recved", format_bytes(sum(map(len, data)))) self.total_recvd += sum(map(len, data)) # An ugly way of turning these batches back into an arrow table - data = await offload( - list_of_buffers_to_table, - data, - schema=pa.Schema.from_pandas(self.metadata.empty), - ) + with self.time("cpu"): + data = await offload( + list_of_buffers_to_table, + data, + schema=pa.Schema.from_pandas(self.metadata.empty), + ) - groups = await offload(split_by_partition, data, self.metadata.column) + groups = await offload(split_by_partition, data, self.metadata.column) assert len(data) == sum(map(len, groups.values())) del data - groups = await offload( - lambda: { - k: [batch.serialize() for batch in v.to_batches()] - for k, v in groups.items() - } - ) # TODO: consider offloading + with self.time("cpu"): + groups = await offload( + lambda: { + k: [batch.serialize() for batch in v.to_batches()] + for k, v in groups.items() + } + ) # TODO: consider offloading await self.multi_file.put(groups) def add_partition(self, data: pd.DataFrame) -> None: - out = split_by_worker( - data, self.metadata.column, self.metadata.npartitions, self.metadata.workers - ) - assert len(data) == sum(map(len, out.values())) - out = { - k: [b.serialize().to_pybytes() for b in t.to_batches()] - for k, t in out.items() - } + with self.time("cpu"): + out = split_by_worker( + data, + self.metadata.column, + self.metadata.npartitions, + self.metadata.workers, + ) + assert len(data) == sum(map(len, out.values())) + out = { + k: [b.serialize().to_pybytes() for b in t.to_batches()] + for k, t in out.items() + } self.multi_comm.put(out) def get_output_partition(self, i: int) -> pd.DataFrame: @@ -188,7 +209,8 @@ def get_output_partition(self, i: int) -> pd.DataFrame: sync(self.worker.loop, self.multi_file.flush) # type: ignore try: df = self.multi_file.read(i) - return df.to_pandas() + with self.time("cpu"): + return df.to_pandas() except KeyError: return self.metadata.empty.head(0) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index a6fbae6a37b..5e10138a18c 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -14,10 +14,10 @@ async def test_basic(c, s, a, b): ) df = dask.datasets.timeseries() out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - a, b = c.compute([df.x.size, out.x.size]) - a = await a - b = await b - assert a == b + x, y = c.compute([df.x.size, out.x.size]) + x = await x + y = await y + assert x == y @gen_cluster(client=True) @@ -31,4 +31,4 @@ async def test_heartbeat(c, s, a, b): out = dd.shuffle.shuffle(df, "x", shuffle="p2p") await out.persist() - assert s.extensions["shuffle"].shuffles + [s] = s.extensions["shuffle"].shuffles.values() From 8c28b83eb8fce606232d0c5a11ed1a89ab2d5404 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 17 Mar 2022 16:46:33 -0500 Subject: [PATCH 23/81] fixup tests --- distributed/scheduler.py | 6 ++++-- distributed/tests/test_worker_client.py | 6 +----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 2eb62b9de71..5e68f6eab3a 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -183,6 +183,7 @@ def nogil(func): "events": EventExtension, "amm": ActiveMemoryManagerExtension, "memory_sampler": MemorySamplerExtension, + "stealing": WorkStealing, } ALL_TASK_STATES = declare( @@ -4012,8 +4013,9 @@ def __init__( if extensions is None: extensions = DEFAULT_EXTENSIONS.copy() - if dask.config.get("distributed.scheduler.work-stealing"): - extensions["stealing"] = WorkStealing + if not dask.config.get("distributed.scheduler.work-stealing"): + if "stealing" in extensions: + del extensions["stealing"] self._extensions = { name: extension(self) for name, extension in extensions.items() } diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index ecdfa8fd003..468bd90d463 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -254,13 +254,9 @@ def test_secede_without_stealing_issue_1262(): Tests that seceding works with the Stealing extension disabled https://github.com/dask/distributed/issues/1262 """ - - # turn off all extensions - extensions = [] - # run the loop as an inner function so all workers are closed # and exceptions can be examined - @gen_cluster(client=True, scheduler_kwargs={"extensions": extensions}) + @gen_cluster(client=True, scheduler_kwargs={"extensions": {}}) async def secede_test(c, s, a, b): def func(x): with worker_client() as wc: From 29253b47687aa589d11af3d626c1416ae6ae90a8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 18 Mar 2022 11:26:49 -0500 Subject: [PATCH 24/81] Use names for client extensions --- distributed/client.py | 10 +++++++--- distributed/pubsub.py | 1 - 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index ce7652e762b..1407cb2b03b 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -115,7 +115,10 @@ _current_client = ContextVar("_current_client", default=None) -DEFAULT_EXTENSIONS = [PubSubClientExtension] +DEFAULT_EXTENSIONS = { + "pubsub": PubSubClientExtension, +} + # Placeholder used in the get_dataset function(s) NO_DEFAULT_PLACEHOLDER = "_no_default_" @@ -928,8 +931,9 @@ def __init__( server=self, ) - for ext in extensions: - ext(self) + self.extensions = { + name: extension(self) for name, extension in extensions.items() + } preload = dask.config.get("distributed.client.preload") preload_argv = dask.config.get("distributed.client.preload-argv") diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 1bfcb490af3..f575439c3a0 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -176,7 +176,6 @@ def __init__(self, client): self.client._stream_handlers.update({"pubsub-msg": self.handle_message}) self.subscribers = defaultdict(weakref.WeakSet) - self.client.extensions["pubsub"] = self # TODO: circular reference async def handle_message(self, name=None, msg=None): for sub in self.subscribers[name]: From 1f7957543c4987134887d1709d3df93839051a5d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 18 Mar 2022 11:44:44 -0500 Subject: [PATCH 25/81] Add back in manual addition of stealing extension Tests are failing. I can't reproduce locally. This is just blind praying that it fixes the problem. It should be innocuous. --- distributed/stealing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/stealing.py b/distributed/stealing.py index 9789cb58b45..bc51cabee38 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -80,6 +80,7 @@ def __init__(self, scheduler): self._in_flight_event = asyncio.Event() self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm + self.scheduler.extensions["stealing"] = self async def start(self, scheduler=None): """Start the background coroutine to balance the tasks on the cluster. From 6f2286e4ecb49c06472d25d802eee674baa91b6a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 18 Mar 2022 12:38:15 -0500 Subject: [PATCH 26/81] Add basic shuffling dashboard --- distributed/dashboard/components/scheduler.py | 207 ++++++++++++++++++ distributed/dashboard/scheduler.py | 2 + .../dashboard/tests/test_scheduler_bokeh.py | 16 ++ distributed/shuffle/shuffle_extension.py | 2 + 4 files changed, 227 insertions(+) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 1dc529b19f8..28a52a4a7c4 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3284,6 +3284,201 @@ def update(self): self.source.data.update(data) +class Shuffling(DashboardComponent): + """Occupancy (in time) per worker""" + + def __init__(self, scheduler, **kwargs): + with log_errors(): + self.scheduler = scheduler + self.source = ColumnDataSource( + { + "worker": [], + "y": [], + "comm_memory": [], + "comm_memory_half": [], + "comm_memory_limit": [], + "comm_buckets": [], + "comm_active": [], + "comm_avg_duration": [], + "comm_avg_size": [], + "comm_read": [], + "comm_written": [], + "disk_memory": [], + "disk_memory_half": [], + "disk_memory_limit": [], + "disk_buckets": [], + "disk_active": [], + "disk_avg_duration": [], + "disk_avg_size": [], + "disk_read": [], + "disk_written": [], + } + ) + + self.comm_memory = figure( + title="Comms Buffer", + tools="", + toolbar_location="above", + ) + self.comm_memory.rect( + source=self.source, + x="comm_memory_half", + width="comm_memory", + y="y", + height=0.9, + ) + hover = HoverTool( + tooltips=""" +
+ Memory Used:  + @comm_memory{0.00 b} +
+
+ Average Write:  + @comm_avg_size{0.00 b} +
+
+ # Buckets:  + @comm_buckets +
+
+ Average Duration:  + @comm_avg_duration +
+ """, + ) + hover.point_policy = "follow_mouse" + self.comm_memory.add_tools(hover) + self.comm_memory.x_range.start = 0 + self.comm_memory.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + + self.disk_memory = figure( + title="Disk Buffer", + tools="", + toolbar_location="above", + ) + self.disk_memory.yaxis.visible = False + + self.disk_memory.rect( + source=self.source, + x="disk_memory_half", + width="disk_memory", + y="y", + height=0.9, + ) + + hover = HoverTool( + tooltips=""" +
+ Memory Used:  + @disk_memory{0.00 b} +
+
+ Average Write:  + @disk_avg_size{0.00 b} +
+
+ # Buckets:  + @disk_buckets +
+
+ Average Duration:  + @disk_avg_duration +
+ """, + ) + hover.point_policy = "follow_mouse" + self.disk_memory.add_tools(hover) + self.disk_memory.x_range.start = 0 + self.disk_memory.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + self.root = row(self.comm_memory, self.disk_memory) + + @without_property_validation + def update(self): + with log_errors(): + input = self.scheduler.extensions["shuffle"].shuffles + if not input: + return + + input = list(input.values())[-1] # TODO: multiple concurrent shuffles + + data = { + "worker": [], + "y": [], + "comm_memory": [], + "comm_memory_half": [], + "comm_memory_limit": [], + "comm_buckets": [], + "comm_active": [], + "comm_avg_duration": [], + "comm_avg_size": [], + "comm_read": [], + "comm_written": [], + "disk_memory": [], + "disk_memory_half": [], + "disk_memory_limit": [], + "disk_buckets": [], + "disk_active": [], + "disk_avg_duration": [], + "disk_avg_size": [], + "disk_read": [], + "disk_written": [], + } + + for i, (worker, d) in enumerate(input.items()): + data["y"].append(i) + data["worker"].append(worker) + data["comm_memory"].append(d["comms"]["memory"]) + data["comm_memory_half"].append(d["comms"]["memory"] / 2) + data["comm_memory_limit"].append(d["comms"]["memory_limit"]) + data["comm_buckets"].append(d["comms"]["buckets"]) + data["comm_active"].append(d["comms"]["active"]) + data["comm_avg_duration"].append( + d["comms"]["diagnostics"]["avg_duration"] + ) + data["comm_avg_size"].append(d["comms"]["diagnostics"]["avg_size"]) + data["comm_read"].append(d["comms"]["read"]) + data["comm_written"].append(d["comms"]["written"]) + + data["disk_memory"].append(d["disk"]["memory"]) + data["disk_memory_half"].append(d["disk"]["memory"] / 2) + data["disk_memory_limit"].append(d["disk"]["memory_limit"]) + data["disk_buckets"].append(d["disk"]["buckets"]) + data["disk_active"].append(d["disk"]["active"]) + data["disk_avg_duration"].append( + d["disk"]["diagnostics"]["avg_duration"] + ) + data["disk_avg_size"].append(d["disk"]["diagnostics"]["avg_size"]) + data["disk_read"].append(d["disk"]["read"]) + data["disk_written"].append(d["disk"]["written"]) + + singletons = { + "comm_avg_duration": [ + sum(data["comm_avg_duration"]) / len(data["comm_avg_duration"]) + ], + "comm_avg_size": [ + sum(data["comm_avg_size"]) / len(data["comm_avg_size"]) + ], + "disk_avg_duration": [ + sum(data["disk_avg_duration"]) / len(data["disk_avg_duration"]) + ], + "disk_avg_size": [ + sum(data["disk_avg_size"]) / len(data["disk_avg_size"]) + ], + } + singletons["comm_avg_bandwidth"] = [ + singletons["comm_avg_size"][0] / singletons["comm_avg_duration"][0] + ] + singletons["disk_avg_bandwidth"] = [ + singletons["disk_avg_size"][0] / singletons["disk_avg_duration"][0] + ] + singletons["y"] = [data["y"][-1] / 2] + + update(self.source, data) + self.comm_memory.x_range.end = max(data["comm_memory_limit"]) * 1.2 + self.disk_memory.x_range.end = max(data["disk_memory_limit"]) * 1.2 + + class SchedulerLogs: def __init__(self, scheduler, start=None): logs = scheduler.get_logs(start=start, timestamps=True) @@ -3328,6 +3523,18 @@ def systemmonitor_doc(scheduler, extra, doc): doc.theme = BOKEH_THEME +def shuffling_doc(scheduler, extra, doc): + with log_errors(): + shuffling = Shuffling(scheduler, sizing_mode="stretch_both") + doc.title = "Dask: Shuffling" + add_periodic_callback(doc, shuffling, 500) + + doc.add_root(shuffling.root) + doc.template = env.get_template("simple.html") + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + def stealing_doc(scheduler, extra, doc): with log_errors(): occupancy = Occupancy(scheduler) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 3d8e62d95ff..a46d4d05d5a 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -35,6 +35,7 @@ individual_profile_server_doc, profile_doc, profile_server_doc, + shuffling_doc, status_doc, stealing_doc, systemmonitor_doc, @@ -47,6 +48,7 @@ applications = { "/system": systemmonitor_doc, + "/shuffle": shuffling_doc, "/stealing": stealing_doc, "/workers": workers_doc, "/events": events_doc, diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 2907482e078..325406ca166 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -28,6 +28,7 @@ Occupancy, ProcessingHistogram, ProfileServer, + Shuffling, StealingEvents, StealingTimeSeries, SystemMonitor, @@ -991,3 +992,18 @@ async def test_prefix_bokeh(s, a, b): bokeh_app = s.http_application.applications[0] assert isinstance(bokeh_app, BokehTornado) assert bokeh_app.prefix == f"/{prefix}" + + +@gen_cluster(client=True, worker_kwargs={"dashboard": True}) +async def test_shuffling(c, s, a, b): + dd = pytest.importorskip("dask.dataframe") + ss = Shuffling(s) + + df = dask.datasets.timeseries() + df2 = dd.shuffle.shuffle(df, "x", shuffle="p2p").persist() + + start = time() + while not ss.source.data["disk_read"]: + ss.update() + await asyncio.sleep(0.1) + assert time() < start + 5 diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 6c1e53285b3..7af1e6eab88 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -130,6 +130,7 @@ def heartbeat(self): "read": self.multi_file.bytes_read, "active": len(self.multi_file.active), "diagnostics": self.multi_file.diagnostics, + "memory_limit": self.multi_file.memory_limit, }, "comms": { "memory": self.multi_comm.total_size, @@ -138,6 +139,7 @@ def heartbeat(self): "read": self.total_recvd, "active": self.multi_comm.comm_queue.qsize(), # TODO: maybe not built yet "diagnostics": self.multi_comm.diagnostics, + "memory_limit": self.multi_comm.memory_limit, }, "diagnostics": self.diagnostics, "start": self.start_time, From 456de23303bf55d7abeaba5221e95e18c2bcd88a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 18 Mar 2022 13:04:36 -0500 Subject: [PATCH 27/81] Add colors to shuffling plots --- distributed/dashboard/components/scheduler.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 28a52a4a7c4..23289c3f082 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3319,6 +3319,7 @@ def __init__(self, scheduler, **kwargs): title="Comms Buffer", tools="", toolbar_location="above", + x_range=Range1d(0, 100_000_000), ) self.comm_memory.rect( source=self.source, @@ -3326,6 +3327,7 @@ def __init__(self, scheduler, **kwargs): width="comm_memory", y="y", height=0.9, + color="comm_color", ) hover = HoverTool( tooltips=""" @@ -3350,12 +3352,14 @@ def __init__(self, scheduler, **kwargs): hover.point_policy = "follow_mouse" self.comm_memory.add_tools(hover) self.comm_memory.x_range.start = 0 + self.comm_memory.x_range.end = 1 self.comm_memory.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") self.disk_memory = figure( title="Disk Buffer", tools="", toolbar_location="above", + x_range=Range1d(0, 100_000_000), ) self.disk_memory.yaxis.visible = False @@ -3365,6 +3369,7 @@ def __init__(self, scheduler, **kwargs): width="disk_memory", y="y", height=0.9, + color="disk_color", ) hover = HoverTool( @@ -3389,7 +3394,6 @@ def __init__(self, scheduler, **kwargs): ) hover.point_policy = "follow_mouse" self.disk_memory.add_tools(hover) - self.disk_memory.x_range.start = 0 self.disk_memory.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") self.root = row(self.comm_memory, self.disk_memory) @@ -3414,6 +3418,7 @@ def update(self): "comm_avg_size": [], "comm_read": [], "comm_written": [], + "comm_color": [], "disk_memory": [], "disk_memory_half": [], "disk_memory_limit": [], @@ -3423,6 +3428,7 @@ def update(self): "disk_avg_size": [], "disk_read": [], "disk_written": [], + "disk_color": [], } for i, (worker, d) in enumerate(input.items()): @@ -3439,6 +3445,12 @@ def update(self): data["comm_avg_size"].append(d["comms"]["diagnostics"]["avg_size"]) data["comm_read"].append(d["comms"]["read"]) data["comm_written"].append(d["comms"]["written"]) + if d["comms"]["active"]: + data["comm_color"].append("green") + elif d["comms"]["memory"] > d["comms"]["memory_limit"]: + data["comm_color"].append("red") + else: + data["comm_color"].append("blue") data["disk_memory"].append(d["disk"]["memory"]) data["disk_memory_half"].append(d["disk"]["memory"] / 2) @@ -3451,6 +3463,12 @@ def update(self): data["disk_avg_size"].append(d["disk"]["diagnostics"]["avg_size"]) data["disk_read"].append(d["disk"]["read"]) data["disk_written"].append(d["disk"]["written"]) + if d["disk"]["active"]: + data["disk_color"].append("green") + elif d["comms"]["memory"] > d["comms"]["memory_limit"]: + data["disk_color"].append("red") + else: + data["disk_color"].append("blue") singletons = { "comm_avg_duration": [ From aef2f61d2a1dc80dec8b8695a8c97f7fddfac88d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 18 Mar 2022 13:21:03 -0500 Subject: [PATCH 28/81] make larger dashboard page --- distributed/dashboard/components/scheduler.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 23289c3f082..d5baf6ed26c 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3320,6 +3320,7 @@ def __init__(self, scheduler, **kwargs): tools="", toolbar_location="above", x_range=Range1d(0, 100_000_000), + **kwargs, ) self.comm_memory.rect( source=self.source, @@ -3360,6 +3361,7 @@ def __init__(self, scheduler, **kwargs): tools="", toolbar_location="above", x_range=Range1d(0, 100_000_000), + **kwargs, ) self.disk_memory.yaxis.visible = False @@ -3543,11 +3545,24 @@ def systemmonitor_doc(scheduler, extra, doc): def shuffling_doc(scheduler, extra, doc): with log_errors(): - shuffling = Shuffling(scheduler, sizing_mode="stretch_both") doc.title = "Dask: Shuffling" - add_periodic_callback(doc, shuffling, 500) - doc.add_root(shuffling.root) + shuffling = Shuffling(scheduler, width=400, height=400) + workers_memory = WorkersMemory(scheduler, width=400, height=400) + timeseries = SystemTimeseries(scheduler, width=1200, height=200) + + add_periodic_callback(doc, shuffling, 200) + add_periodic_callback(doc, workers_memory, 200) + add_periodic_callback(doc, timeseries, 500) + + timeseries.bandwidth.y_range = timeseries.disk.y_range + + doc.add_root( + column( + row(workers_memory.root, shuffling.comm_memory, shuffling.disk_memory), + row(column(timeseries.bandwidth, timeseries.disk)), + ) + ) doc.template = env.get_template("simple.html") doc.template_variables.update(extra) doc.theme = BOKEH_THEME From f79e923ca1367046166c1d85ca14858cf36e5620 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 18 Mar 2022 13:45:06 -0500 Subject: [PATCH 29/81] extend shuffling dashboard --- distributed/dashboard/components/scheduler.py | 89 ++++++++++++++++--- distributed/shuffle/multi_file.py | 1 + distributed/shuffle/shuffle_extension.py | 4 +- 3 files changed, 82 insertions(+), 12 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index d5baf6ed26c..c0dae4a00b3 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -893,7 +893,7 @@ class SystemTimeseries(DashboardComponent): from ws.metrics["val"] for ws in scheduler.workers.values() divided by nuber of workers. """ - def __init__(self, scheduler, **kwargs): + def __init__(self, scheduler, follow_interval=20000, **kwargs): with log_errors(): self.scheduler = scheduler self.source = ColumnDataSource( @@ -910,7 +910,9 @@ def __init__(self, scheduler, **kwargs): update(self.source, self.get_data()) - x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) + x_range = DataRange1d( + follow="end", follow_interval=follow_interval, range_padding=0 + ) tools = "reset, xpan, xwheel_zoom" self.bandwidth = figure( @@ -3314,6 +3316,12 @@ def __init__(self, scheduler, **kwargs): "disk_written": [], } ) + self.totals_source = ColumnDataSource( + { + "x": ["Network Send", "Network Receive", "Disk Write", "Disk Read"], + "value": [0, 0, 0, 0], + } + ) self.comm_memory = figure( title="Comms Buffer", @@ -3397,6 +3405,44 @@ def __init__(self, scheduler, **kwargs): hover.point_policy = "follow_mouse" self.disk_memory.add_tools(hover) self.disk_memory.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + + self.totals = figure( + title="Total movement", + tools="", + toolbar_location="above", + **kwargs, + ) + titles = ["Network Send", "Network Receive", "Disk Write", "Disk Read"] + self.totals = figure( + x_range=titles, + title="Totals", + toolbar_location=None, + tools="", + **kwargs, + ) + + self.totals.vbar( + x="x", + top="values", + width=0.9, + source=self.totals_source, + ) + + self.totals.xgrid.grid_line_color = None + self.totals.y_range.start = 0 + self.totals.yaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + + hover = HoverTool( + tooltips=""" +
+ @x:  + @values{0.00 b} +
+ """, + ) + hover.point_policy = "follow_mouse" + self.totals.add_tools(hover) + self.root = row(self.comm_memory, self.disk_memory) @without_property_validation @@ -3442,9 +3488,11 @@ def update(self): data["comm_buckets"].append(d["comms"]["buckets"]) data["comm_active"].append(d["comms"]["active"]) data["comm_avg_duration"].append( - d["comms"]["diagnostics"]["avg_duration"] + d["comms"]["diagnostics"].get("avg_duration", 0) + ) + data["comm_avg_size"].append( + d["comms"]["diagnostics"].get("avg_size", 0) ) - data["comm_avg_size"].append(d["comms"]["diagnostics"]["avg_size"]) data["comm_read"].append(d["comms"]["read"]) data["comm_written"].append(d["comms"]["written"]) if d["comms"]["active"]: @@ -3460,9 +3508,11 @@ def update(self): data["disk_buckets"].append(d["disk"]["buckets"]) data["disk_active"].append(d["disk"]["active"]) data["disk_avg_duration"].append( - d["disk"]["diagnostics"]["avg_duration"] + d["disk"]["diagnostics"].get("avg_duration", 0) + ) + data["disk_avg_size"].append( + d["disk"]["diagnostics"].get("avg_size", 0) ) - data["disk_avg_size"].append(d["disk"]["diagnostics"]["avg_size"]) data["disk_read"].append(d["disk"]["read"]) data["disk_written"].append(d["disk"]["written"]) if d["disk"]["active"]: @@ -3494,9 +3544,21 @@ def update(self): ] singletons["y"] = [data["y"][-1] / 2] + totals = { + "x": ["Network Send", "Network Receive", "Disk Write", "Disk Read"], + "values": [ + sum(data["comm_written"]), + sum(data["comm_read"]), + sum(data["disk_written"]), + sum(data["disk_read"]), + ], + } + update(self.totals_source, totals) + update(self.source, data) - self.comm_memory.x_range.end = max(data["comm_memory_limit"]) * 1.2 - self.disk_memory.x_range.end = max(data["disk_memory_limit"]) * 1.2 + limit = max(data["comm_memory_limit"] + data["disk_memory_limit"]) * 1.2 + self.comm_memory.x_range.end = limit + self.disk_memory.x_range.end = limit class SchedulerLogs: @@ -3549,7 +3611,9 @@ def shuffling_doc(scheduler, extra, doc): shuffling = Shuffling(scheduler, width=400, height=400) workers_memory = WorkersMemory(scheduler, width=400, height=400) - timeseries = SystemTimeseries(scheduler, width=1200, height=200) + timeseries = SystemTimeseries( + scheduler, width=1400, height=200, follow_interval=3000 + ) add_periodic_callback(doc, shuffling, 200) add_periodic_callback(doc, workers_memory, 200) @@ -3559,7 +3623,12 @@ def shuffling_doc(scheduler, extra, doc): doc.add_root( column( - row(workers_memory.root, shuffling.comm_memory, shuffling.disk_memory), + row( + workers_memory.root, + shuffling.comm_memory, + shuffling.disk_memory, + shuffling.totals, + ), row(column(timeseries.bandwidth, timeseries.disk)), ) ) diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index fcee55eafd6..5581f44e516 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -149,6 +149,7 @@ def _(): ] + 0.02 * (stop - start) self.active.remove(id) + self.bytes_written += size self.total_size -= size async with self.condition: self.condition.notify() diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 7af1e6eab88..50fb816f805 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -270,8 +270,8 @@ async def shuffle_receive( Using an unknown ``shuffle_id`` is an error. """ shuffle = self._get_shuffle(shuffle_id) - await shuffle.receive(data) - return + # await shuffle.receive(data) + # return # TODO: it would be good to not have comms wait on disk if not # necessary future = asyncio.ensure_future(shuffle.receive(data)) From 1e0256f7546faebfe0b9665df870368629bbd237 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 18 Mar 2022 16:39:04 -0500 Subject: [PATCH 30/81] Don't offload file writes also remove printing --- distributed/dashboard/components/scheduler.py | 33 ++++++++----------- distributed/shuffle/multi_comm.py | 28 ++-------------- distributed/shuffle/multi_file.py | 33 ++++--------------- distributed/shuffle/shuffle_extension.py | 5 +-- 4 files changed, 23 insertions(+), 76 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index c0dae4a00b3..47253ca1e31 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3305,6 +3305,7 @@ def __init__(self, scheduler, **kwargs): "comm_avg_size": [], "comm_read": [], "comm_written": [], + "comm_color": [], "disk_memory": [], "disk_memory_half": [], "disk_memory_limit": [], @@ -3314,12 +3315,13 @@ def __init__(self, scheduler, **kwargs): "disk_avg_size": [], "disk_read": [], "disk_written": [], + "disk_color": [], } ) self.totals_source = ColumnDataSource( { "x": ["Network Send", "Network Receive", "Disk Write", "Disk Read"], - "value": [0, 0, 0, 0], + "values": [0, 0, 0, 0], } ) @@ -3339,24 +3341,13 @@ def __init__(self, scheduler, **kwargs): color="comm_color", ) hover = HoverTool( - tooltips=""" -
- Memory Used:  - @comm_memory{0.00 b} -
-
- Average Write:  - @comm_avg_size{0.00 b} -
-
- # Buckets:  - @comm_buckets -
-
- Average Duration:  - @comm_avg_duration -
- """, + tooltips=[ + ("Memory Used", "@comm_memory{0.00 b}"), + ("Average Write", "@comm_avg_size{0.00 b}"), + ("# Buckets", "@comm_buckets"), + ("Average Duration", "@comm_avg_duration"), + ], + formatters={"@comm_avg_duration": "datetime"}, ) hover.point_policy = "follow_mouse" self.comm_memory.add_tools(hover) @@ -3517,11 +3508,12 @@ def update(self): data["disk_written"].append(d["disk"]["written"]) if d["disk"]["active"]: data["disk_color"].append("green") - elif d["comms"]["memory"] > d["comms"]["memory_limit"]: + elif d["disk"]["memory"] > d["disk"]["memory_limit"]: data["disk_color"].append("red") else: data["disk_color"].append("blue") + """ singletons = { "comm_avg_duration": [ sum(data["comm_avg_duration"]) / len(data["comm_avg_duration"]) @@ -3543,6 +3535,7 @@ def update(self): singletons["disk_avg_size"][0] / singletons["disk_avg_duration"][0] ] singletons["y"] = [data["y"][-1] / 2] + """ totals = { "x": ["Network Send", "Network Receive", "Disk Write", "Disk Read"], diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index 001ede98c84..2906365fdd2 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -53,16 +53,6 @@ def put(self, data: dict): del data, shard - from dask.utils import format_bytes - - if self.total_size > self.memory_limit: - print( - "waiting comm", - format_bytes(self.total_size), - "this", - format_bytes(size), - ) - while self.total_size > self.memory_limit: with self.time("waiting-on-memory"): with self.thread_condition: @@ -73,8 +63,6 @@ async def communicate(self): for _ in range(self.max_connections): self.comm_queue.put_nowait(None) - from dask.utils import format_bytes - while not self._done: with self.time("idle"): if not self.shards: @@ -105,15 +93,6 @@ async def communicate(self): assert set(self.sizes) == set(self.shards) assert shards - print( - "Sending", - format_bytes(size), - "to comm", - format_bytes(self.total_size), - "left in ", - len(self.shards), - "buckets", - ) future = asyncio.ensure_future(self.process(address, shards, size)) del shards self._futures.add(future) @@ -126,10 +105,12 @@ async def process(self, address: str, shards: list, size: int): # Consider boosting total_size a bit here to account for duplication try: + # while (time.time() // 5 % 4) == 0: + # await asyncio.sleep(0.1) start = time.time() with self.time("send"): await self.rpc(address).shuffle_receive( - data=to_serialize(shards), + data=to_serialize([b"".join(shards)]), shuffle_id=self.shuffle_id, ) stop = time.time() @@ -155,9 +136,6 @@ async def flush(self): if self.total_size: breakpoint() assert not self.total_size - from dask.utils import format_bytes - - print("total moved", format_bytes(self.total_moved)) self._done = True diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 5581f44e516..56ad9b67a10 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -11,7 +11,7 @@ from dask.utils import parse_bytes from distributed.system import MEMORY_LIMIT -from distributed.utils import log_errors, offload +from distributed.utils import log_errors class MultiFile: @@ -63,16 +63,6 @@ async def put(self, data: dict): del data, shard - from dask.utils import format_bytes - - if self.total_size > self.memory_limit: - print( - "waiting disk", - format_bytes(self.total_size), - "this", - format_bytes(this_size), - ) - while self.total_size > self.memory_limit: with self.time("waiting-on-memory"): async with self.condition: @@ -101,17 +91,6 @@ async def communicate(self): id = max(self.sizes, key=self.sizes.get) shards = self.shards.pop(id) size = self.sizes.pop(id) - from dask.utils import format_bytes - - print( - "Writing", - format_bytes(size), - "to disk", - format_bytes(self.total_size), - "left in", - len(self.shards), - "buckets", - ) future = asyncio.ensure_future(self.process(id, shards, size)) del shards @@ -128,17 +107,17 @@ async def process(self, id: str, shards: list, size: int): self.active.add(id) def _(): - # TODO: offload with open( self.directory / str(id), mode="ab", buffering=100_000_000 ) as f: for shard in shards: self.dump(shard, f) - os.fsync(f) # TODO: maybe? + # os.fsync(f) # TODO: maybe? start = time.time() with self.time("write"): - await offload(_) + _() + # await offload(_) stop = time.time() self.diagnostics["avg_size"] = ( @@ -168,13 +147,13 @@ def read(self, id): parts.append(self.load(f)) except EOFError: break + size = f.tell() except FileNotFoundError: raise KeyError(id) # TODO: We could consider deleting the file at this point if parts: - for part in parts: - self.bytes_read += sizeof(part) + self.bytes_read += size return self.join(parts) else: raise KeyError(id) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 50fb816f805..892d4e69382 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -98,7 +98,7 @@ def __init__( sizeof=lambda L: sum(map(len, L)), ) self.multi_comm = MultiComm( - memory_limit="200 MiB", # TODO + memory_limit="500 MiB", # TODO rpc=worker.rpc, shuffle_id=self.metadata.id, sizeof=lambda L: sum(map(len, L)), @@ -152,9 +152,6 @@ async def receive(self, data: list[pa.Buffer]) -> None: # assert not self.transferred, "`receive` called after barrier task" import pyarrow as pa - from dask.utils import format_bytes - - print("recved", format_bytes(sum(map(len, data)))) self.total_recvd += sum(map(len, data)) # An ugly way of turning these batches back into an arrow table with self.time("cpu"): From 97fb09c0657a74984270ef691073df3e7a056585 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 18 Mar 2022 16:49:57 -0500 Subject: [PATCH 31/81] reduce comm memory limit --- distributed/shuffle/shuffle_extension.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 892d4e69382..7aef2ee083f 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -98,12 +98,13 @@ def __init__( sizeof=lambda L: sum(map(len, L)), ) self.multi_comm = MultiComm( - memory_limit="500 MiB", # TODO + memory_limit="300 MiB", # TODO rpc=worker.rpc, shuffle_id=self.metadata.id, sizeof=lambda L: sum(map(len, L)), join=functools.partial(sum, start=[]), max_connections=min((len(self.metadata.workers) - 1) or 1, 10), + max_message_size="10 MiB", ) self.worker.loop.add_callback(self.multi_comm.communicate) self.worker.loop.add_callback(self.multi_file.communicate) @@ -267,12 +268,7 @@ async def shuffle_receive( Using an unknown ``shuffle_id`` is an error. """ shuffle = self._get_shuffle(shuffle_id) - # await shuffle.receive(data) - # return - # TODO: it would be good to not have comms wait on disk if not - # necessary future = asyncio.ensure_future(shuffle.receive(data)) - # await future # backpressure if ( shuffle.multi_file.total_size + sum(map(len, data)) > shuffle.multi_file.memory_limit From f58b2e9efca195a8cc75d346c567dbefcf9ddf59 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 18 Mar 2022 22:30:18 -0500 Subject: [PATCH 32/81] use multi-threaded thread-pool and swap np.unique for pd.Series.unique --- distributed/shuffle/shuffle_extension.py | 36 ++++++++++++++++-------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 7aef2ee083f..0598a96cfc0 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -7,6 +7,7 @@ import os import time from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from typing import TYPE_CHECKING, NewType @@ -15,7 +16,7 @@ from distributed.protocol import to_serialize from distributed.shuffle.multi_comm import MultiComm from distributed.shuffle.multi_file import MultiFile -from distributed.utils import offload, sync +from distributed.utils import sync if TYPE_CHECKING: import pandas as pd @@ -80,9 +81,11 @@ def __init__( self, metadata: ShuffleMetadata, worker: Worker, + executor: ThreadPoolExecutor, ) -> None: self.metadata = metadata self.worker = worker + self.executor = executor import pyarrow as pa @@ -98,13 +101,13 @@ def __init__( sizeof=lambda L: sum(map(len, L)), ) self.multi_comm = MultiComm( - memory_limit="300 MiB", # TODO + memory_limit="100 MiB", # TODO rpc=worker.rpc, shuffle_id=self.metadata.id, sizeof=lambda L: sum(map(len, L)), join=functools.partial(sum, start=[]), - max_connections=min((len(self.metadata.workers) - 1) or 1, 10), - max_message_size="10 MiB", + max_connections=min(len(self.metadata.workers), 10), + max_message_size="2 MiB", ) self.worker.loop.add_callback(self.multi_comm.communicate) self.worker.loop.add_callback(self.multi_file.communicate) @@ -122,6 +125,14 @@ def time(self, name: str): stop = time.time() self.diagnostics[name] += stop - start + async def offload(self, func, *args): + # return func(*args) + return await asyncio.get_event_loop().run_in_executor( + self.executor, + func, + *args, + ) + def heartbeat(self): return { "disk": { @@ -156,24 +167,24 @@ async def receive(self, data: list[pa.Buffer]) -> None: self.total_recvd += sum(map(len, data)) # An ugly way of turning these batches back into an arrow table with self.time("cpu"): - data = await offload( + data = await self.offload( list_of_buffers_to_table, data, - schema=pa.Schema.from_pandas(self.metadata.empty), + pa.Schema.from_pandas(self.metadata.empty), ) - groups = await offload(split_by_partition, data, self.metadata.column) + groups = await self.offload(split_by_partition, data, self.metadata.column) assert len(data) == sum(map(len, groups.values())) del data with self.time("cpu"): - groups = await offload( + groups = await self.offload( lambda: { k: [batch.serialize() for batch in v.to_batches()] for k, v in groups.items() } - ) # TODO: consider offloading + ) await self.multi_file.put(groups) def add_partition(self, data: pd.DataFrame) -> None: @@ -235,6 +246,7 @@ def __init__(self, worker: Worker) -> None: # Initialize self.worker: Worker = worker self.shuffles: dict[ShuffleId, Shuffle] = {} + self.executor = ThreadPoolExecutor(worker.nthreads) # Handlers ########## @@ -252,6 +264,7 @@ def shuffle_init(self, comm: object, metadata: ShuffleMetadata) -> None: self.shuffles[metadata.id] = Shuffle( metadata, self.worker, + self.executor, ) def heartbeat(self): @@ -437,6 +450,7 @@ def split_by_worker( Split data into many arrow batches, partitioned by destination worker """ import numpy as np + import pandas as pd import pyarrow as pa grouper = (len(workers) * df[column] // npartitions).astype(df[column].dtype).values @@ -456,7 +470,7 @@ def split_by_worker( ] shards.append(t.slice(offset=splits[-1], length=None)) - w = np.unique(grouper) + w = pd.Series(grouper).unique() w.sort() return {workers[w]: shard for w, shard in zip(w, shards)} @@ -471,7 +485,7 @@ def split_by_partition( """ import numpy as np - partitions = np.unique(np.asarray(t.select([column]))[0]) + partitions = t.select([column]).to_pandas()[column].unique() partitions.sort() t = t.sort_by(column) From 9bc6ce686d277391b75dbd28d506f94b5438b4bc Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 19 Mar 2022 07:50:42 -0500 Subject: [PATCH 33/81] removeme: check state of extensions in test --- distributed/tests/test_semaphore.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 50ad43dfce8..5dd493a3a8f 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -463,6 +463,10 @@ async def test_metrics(c, s, a, b): expected_average_pending_lease_time = (time() - before_acquiring) / 2 epsilon = max(0.1, 0.5 * expected_average_pending_lease_time) + if "semaphores" not in s.extensions: + from pprint import pprint + + pprint(s.extensions) sem_ext = s.extensions["semaphores"] actual = sem_ext.metrics.copy() From 1309c22e711c892c356671a8195ce99006d13687 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 19 Mar 2022 08:50:42 -0500 Subject: [PATCH 34/81] I think that there is some strange SchedulerState interation going on --- distributed/scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 1800d4a3a85..1992b9b7af3 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4015,9 +4015,9 @@ def __init__( if not dask.config.get("distributed.scheduler.work-stealing"): if "stealing" in extensions: del extensions["stealing"] - self._extensions = { - name: extension(self) for name, extension in extensions.items() - } + + for name, extension in extensions.items(): + self.extensions[name] = extension(self) setproctitle("dask-scheduler [not started]") Scheduler._instances.add(self) From c894f40ccf8ece40833f9e840c024d4a01239100 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 19 Mar 2022 09:38:41 -0500 Subject: [PATCH 35/81] Track Event Loop intervals in dashboard plot --- distributed/core.py | 21 +++++++- distributed/dashboard/components/scheduler.py | 54 +++++++++++++++++++ distributed/dashboard/scheduler.py | 2 + .../dashboard/tests/test_scheduler_bokeh.py | 9 +++- distributed/distributed.yaml | 1 + distributed/tests/test_worker.py | 30 +++++++++++ distributed/worker.py | 1 + 7 files changed, 115 insertions(+), 3 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 6d043c62b1d..591c37338cc 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -234,11 +234,20 @@ def stop(): self.periodic_callbacks["monitor"] = pc self._last_tick = time() - measure_tick_interval = parse_timedelta( + self._tick_counter = 0 + self._tick_count = 0 + self._tick_count_last = time() + self._tick_interval = parse_timedelta( dask.config.get("distributed.admin.tick.interval"), default="ms" ) - pc = PeriodicCallback(self._measure_tick, measure_tick_interval * 1000) + self._tick_interval_observed = self._tick_interval + pc = PeriodicCallback(self._measure_tick, self._tick_interval * 1000) self.periodic_callbacks["tick"] = pc + pc = PeriodicCallback( + self._cycle_ticks, + parse_timedelta(dask.config.get("distributed.admin.tick.cycle")) * 1000, + ) + self.periodic_callbacks["ticks"] = pc self.thread_id = 0 @@ -351,6 +360,7 @@ def _measure_tick(self): now = time() diff = now - self._last_tick self._last_tick = now + self._tick_counter += 1 if diff > tick_maximum_delay: logger.info( "Event loop was unresponsive in %s for %.2fs. " @@ -363,6 +373,13 @@ def _measure_tick(self): if self.digests is not None: self.digests["tick-duration"].add(diff) + def _cycle_ticks(self): + if not self._tick_counter: + return + last, self._tick_count_last = self._tick_count_last, time() + count, self._tick_counter = self._tick_counter, 0 + self._tick_interval_observed = (time() - last) / (count or 1) + @property def address(self): """ diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 1dc529b19f8..bb01f9e9c9f 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3018,6 +3018,60 @@ def update(self): ) +class EventLoop(DashboardComponent): + """Event Loop Health""" + + def __init__(self, scheduler, **kwargs): + with log_errors(): + self.scheduler = scheduler + self.source = ColumnDataSource( + { + "names": ["Scheduler", "Workers"], + "values": [0, 0], + "text": ["0", "0"], + } + ) + + self.root = figure( + title="Event Loop Health", + x_range=["Scheduler", "Workers"], + y_range=[ + 0, + parse_timedelta(dask.config.get("distributed.admin.tick.interval")) + * 10, + ], + tools="", + toolbar_location="above", + ) + self.root.vbar(x="names", top="values", width=0.9, source=self.source) + + self.root.xaxis.minor_tick_line_alpha = 0 + self.root.ygrid.visible = True + self.root.xgrid.visible = False + + hover = HoverTool() + hover.tooltips = [("Interval", "@text s")] + hover.point_policy = "follow_mouse" + self.root.add_tools(hover) + + @without_property_validation + def update(self): + with log_errors(): + s = self.scheduler + + data = { + "names": ["Scheduler", "Workers"], + "values": [ + s._tick_interval_observed, + sum([w.metrics["event_loop_interval"] for w in s.workers.values()]) + / (len(s.workers) or 1), + ], + } + data["text"] = [format_time(x) for x in data["values"]] + + update(self.source, data) + + class WorkerTable(DashboardComponent): """Status of the current workers diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index 3d8e62d95ff..42c50b732bc 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -16,6 +16,7 @@ ClusterMemory, ComputePerKey, CurrentLoad, + EventLoop, MemoryByKey, Occupancy, SystemMonitor, @@ -97,6 +98,7 @@ "/individual-compute-time-per-key": individual_doc(ComputePerKey, 500), "/individual-aggregate-time-per-action": individual_doc(AggregateAction, 500), "/individual-scheduler-system": individual_doc(SystemMonitor, 500), + "/individual-event-loop": individual_doc(EventLoop, 500), "/individual-profile": individual_profile_doc, "/individual-profile-server": individual_profile_server_doc, "/individual-gpu-memory": gpu_memory_doc, diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 2907482e078..deeb909afa0 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -23,6 +23,7 @@ ClusterMemory, ComputePerKey, CurrentLoad, + EventLoop, Events, MemoryByKey, Occupancy, @@ -75,7 +76,13 @@ async def test_simple(c, s, a, b): @gen_cluster(client=True, worker_kwargs={"dashboard": True}) async def test_basic(c, s, a, b): - for component in [TaskStream, SystemMonitor, Occupancy, StealingTimeSeries]: + for component in [ + TaskStream, + SystemMonitor, + Occupancy, + StealingTimeSeries, + EventLoop, + ]: ss = component(s) ss.update() diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 27642579409..cf66a06f28c 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -270,6 +270,7 @@ distributed: tick: interval: 20ms # time between event loop health checks limit: 3s # time allowed before triggering a warning + cycle: 1s max-error-length: 10000 # Maximum size traceback after error to return log-length: 10000 # default length of logs to keep in memory diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 0caa128c02b..1d30e116250 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3306,3 +3306,33 @@ async def test_Worker__to_dict(c, s, a): } assert d["tasks"]["x"]["key"] == "x" assert d["data"] == ["x"] + + +@gen_cluster( + client=True, + config={ + "distributed.admin.tick.interval": "5ms", + "distributed.admin.tick.cycle": "100ms", + }, +) +async def test_tick_interval(c, s, a, b): + import time + + await a.heartbeat() + x = s.workers[a.address].metrics["event_loop_interval"] + assert x + assert 0.0001 < x < 1 + old = a._tick_interval_observed + + old_count_last = a._tick_count_last + + time.sleep(0.500) # Block event loop + + while a._tick_count_last == old_count_last: + await asyncio.sleep(0.01) + + await a.heartbeat() + y = s.workers[a.address].metrics["event_loop_interval"] + new = a._tick_interval_observed + + assert y > x diff --git a/distributed/worker.py b/distributed/worker.py index 210ecaeb3fe..415b5e38ac0 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -945,6 +945,7 @@ async def get_metrics(self) -> dict: "memory": spilled_memory, "disk": spilled_disk, }, + event_loop_interval=self._tick_interval_observed, ) out.update(self.monitor.recent()) From 486320d8a8aef27830ec427db19b3e2b261fc7c8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 19 Mar 2022 09:55:37 -0500 Subject: [PATCH 36/81] Grey out unseen workers --- distributed/dashboard/components/scheduler.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 47253ca1e31..ea4dede2984 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3469,6 +3469,7 @@ def update(self): "disk_written": [], "disk_color": [], } + now = time() for i, (worker, d) in enumerate(input.items()): data["y"].append(i) @@ -3486,7 +3487,9 @@ def update(self): ) data["comm_read"].append(d["comms"]["read"]) data["comm_written"].append(d["comms"]["written"]) - if d["comms"]["active"]: + if self.scheduler.workers[worker].last_seen < now - 5: + data["comm_color"].append("gray") + elif d["comms"]["active"]: data["comm_color"].append("green") elif d["comms"]["memory"] > d["comms"]["memory_limit"]: data["comm_color"].append("red") @@ -3506,7 +3509,9 @@ def update(self): ) data["disk_read"].append(d["disk"]["read"]) data["disk_written"].append(d["disk"]["written"]) - if d["disk"]["active"]: + if self.scheduler.workers[worker].last_seen < now - 5: + data["disk_color"].append("gray") + elif d["disk"]["active"]: data["disk_color"].append("green") elif d["disk"]["memory"] > d["disk"]["memory_limit"]: data["disk_color"].append("red") From 6419328b41c4c9c7018603aafd459f0262578fd8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 19 Mar 2022 10:21:14 -0500 Subject: [PATCH 37/81] flake8 --- distributed/tests/test_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 6545b7d867e..5620ffd6a59 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3368,4 +3368,4 @@ def heartbeat(self, ws, data: dict): async with Worker(s.address, extensions={"test": WorkerExtension}) as w: await w.heartbeat() - assert flag[0] \ No newline at end of file + assert flag[0] From 9e6aadc48f625f1e9db8f48a492e1833229cee43 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 19 Mar 2022 10:23:57 -0500 Subject: [PATCH 38/81] remove old test This was the result of a bad merge conflict --- distributed/tests/test_worker.py | 34 -------------------------------- 1 file changed, 34 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 5620ffd6a59..0b7d98de969 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3308,40 +3308,6 @@ async def test_Worker__to_dict(c, s, a): assert d["data"] == ["x"] -@gen_cluster(nthreads=[]) -async def test_do_not_block_event_loop_during_shutdown(s): - loop = asyncio.get_running_loop() - called_handler = threading.Event() - block_handler = threading.Event() - - w = await Worker(s.address) - executor = w.executors["default"] - - # The block wait must be smaller than the test timeout and smaller than the - # default value for timeout in `Worker.close`` - async def block(): - def fn(): - called_handler.set() - assert block_handler.wait(20) - - await loop.run_in_executor(executor, fn) - - async def set_future(): - while True: - try: - await loop.run_in_executor(executor, sleep, 0.1) - except RuntimeError: # executor has started shutting down - block_handler.set() - return - - async def close(): - called_handler.wait() - # executor_wait is True by default but we want to be explicit here - await w.close(executor_wait=True) - - await asyncio.gather(block(), close(), set_future()) - - @gen_cluster(nthreads=[]) async def test_extension_heartbeat(s): flag = [False] From c2724587a38c8c6dc3e7c08261ac6d7bd70b809f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 19 Mar 2022 10:48:23 -0500 Subject: [PATCH 39/81] bump y-axis, add kwargs --- distributed/dashboard/components/scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index bb01f9e9c9f..d2af648131a 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3038,10 +3038,11 @@ def __init__(self, scheduler, **kwargs): y_range=[ 0, parse_timedelta(dask.config.get("distributed.admin.tick.interval")) - * 10, + * 25, ], tools="", toolbar_location="above", + **kwargs, ) self.root.vbar(x="names", top="values", width=0.9, source=self.source) From 4ba9923a75d15f687f851c8ae2e716a490875889 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 19 Mar 2022 10:52:40 -0500 Subject: [PATCH 40/81] Add event loop figure to shuffling page --- distributed/dashboard/components/scheduler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index a3396668f32..cba7ccca5c0 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3665,12 +3665,14 @@ def shuffling_doc(scheduler, extra, doc): shuffling = Shuffling(scheduler, width=400, height=400) workers_memory = WorkersMemory(scheduler, width=400, height=400) timeseries = SystemTimeseries( - scheduler, width=1400, height=200, follow_interval=3000 + scheduler, width=1600, height=200, follow_interval=3000 ) + event_loop = EventLoop(scheduler, width=200, height=400) add_periodic_callback(doc, shuffling, 200) add_periodic_callback(doc, workers_memory, 200) add_periodic_callback(doc, timeseries, 500) + add_periodic_callback(doc, event_loop, 500) timeseries.bandwidth.y_range = timeseries.disk.y_range @@ -3681,6 +3683,7 @@ def shuffling_doc(scheduler, extra, doc): shuffling.comm_memory, shuffling.disk_memory, shuffling.totals, + event_loop.root, ), row(column(timeseries.bandwidth, timeseries.disk)), ) From e7a51431405e2d79c73f762b8bd208d9ced61d70 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 21 Mar 2022 13:36:19 -0500 Subject: [PATCH 41/81] Remove errant print --- distributed/tests/test_semaphore.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 5dd493a3a8f..50ad43dfce8 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -463,10 +463,6 @@ async def test_metrics(c, s, a, b): expected_average_pending_lease_time = (time() - before_acquiring) / 2 epsilon = max(0.1, 0.5 * expected_average_pending_lease_time) - if "semaphores" not in s.extensions: - from pprint import pprint - - pprint(s.extensions) sem_ext = s.extensions["semaphores"] actual = sem_ext.metrics.copy() From b0cd7ae166b8d1d0cd0f922458a15deafbd086b6 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 21 Mar 2022 18:43:32 -0500 Subject: [PATCH 42/81] Add test for the compute chain --- distributed/shuffle/shuffle_extension.py | 6 +- distributed/shuffle/tests/test_shuffle.py | 80 +++++++++++++++++++++++ 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 0598a96cfc0..7d3e37911a4 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -470,10 +470,10 @@ def split_by_worker( ] shards.append(t.slice(offset=splits[-1], length=None)) - w = pd.Series(grouper).unique() - w.sort() + w_unique = pd.Series(grouper).unique() + w_unique.sort() - return {workers[w]: shard for w, shard in zip(w, shards)} + return {workers[w]: shard for w, shard in zip(w_unique, shards)} def split_by_partition( diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 5e10138a18c..7305df4f155 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -1,6 +1,19 @@ +import io +from collections import defaultdict + +import pandas as pd +import pyarrow as pa + import dask import dask.dataframe as dd +from distributed.shuffle.shuffle_extension import ( + dump_batch, + list_of_buffers_to_table, + load_arrow, + split_by_partition, + split_by_worker, +) from distributed.utils_test import gen_cluster @@ -32,3 +45,70 @@ async def test_heartbeat(c, s, a, b): await out.persist() [s] = s.extensions["shuffle"].shuffles.values() + + +def test_processing_chain(): + """This is a serial version of the entire compute chain + + In practice this takes place on many different workers. + Here we verify its accuracy in a single threaded situation. + """ + workers = ["a", "b", "c"] + npartitions = 5 + df = pd.DataFrame({"x": range(100), "y": range(100)}) + df["_partitions"] = df.x % npartitions + schema = pa.Schema.from_pandas(df) + + data = split_by_worker(df, "_partitions", npartitions, workers) + assert set(data) == set(workers) + + batches = { + worker: [b.serialize().to_pybytes() for b in t.to_batches()] + for worker, t in data.items() + } + + # Typically we communicate to different workers at this stage + # We then receive them back and reconstute them + + by_worker = { + worker: list_of_buffers_to_table(list_of_batches, schema) + for worker, list_of_batches in batches.items() + } + + # We split them again, and then dump them down to disk + + splits_by_worker = { + worker: split_by_partition(t, "_partitions") for worker, t in by_worker.items() + } + + splits_by_worker = { + worker: { + partition: [batch.serialize() for batch in t.to_batches()] + for partition, t in d.items() + } + for worker, d in splits_by_worker.items() + } + + # No two workers share data from any partition + assert not any( + set(a) & set(b) + for w1, a in splits_by_worker.items() + for w2, b in splits_by_worker.items() + if w1 is not w2 + ) + + # Our simple file system + + filesystem = defaultdict(io.BytesIO) + + for worker, partitions in splits_by_worker.items(): + for partition, batches in partitions.items(): + for batch in batches: + dump_batch(batch, filesystem[partition], schema) + + out = {} + for k, bio in filesystem.items(): + bio.seek(0) + out[k] = load_arrow(bio) + + assert sum(map(len, out.values())) == len(df) From bd37f49cfddb410f81b73ba92f7106c47adfab8d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 21 Mar 2022 19:09:13 -0500 Subject: [PATCH 43/81] Simplify MultiComm and add docstrings --- distributed/shuffle/multi_comm.py | 84 ++++++++++++++++------- distributed/shuffle/shuffle_extension.py | 58 ++++++++++++---- distributed/shuffle/tests/test_shuffle.py | 3 +- 3 files changed, 106 insertions(+), 39 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index 2906365fdd2..7a3e38af2fa 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -6,25 +6,50 @@ from dask.utils import parse_bytes -from distributed.core import rpc -from distributed.protocol import to_serialize -from distributed.sizeof import sizeof from distributed.system import MEMORY_LIMIT from distributed.utils import log_errors class MultiComm: + """Accept, buffer, and send many small messages to many workers + + This takes in lots of small messages destined for remote workers, buffers + those messages in memory, and then sends out batches of them when possible + to different workers. This tries to send larger messages when possible, + while also respecting a memory bound + + **State** + + - shards: dict[str, list[bytes]] + + This is our in-memory buffer of data waiting to be sent to other workers. + + - sizes: dict[str, int] + + The size of each list of shards. We find the largest and send data from that buffer + + Parameters + ---------- + memory_limit: str + A maximum amount of memory to use, like "1 GiB" + max_connections: int + The maximum number of connections to have out at once + max_message_size: str + The maximum size of a single message that we want to send + send: callable + How to send a list of shards to a worker + + """ + def __init__( self, memory_limit=MEMORY_LIMIT / 4, - join=None, - rpc=rpc, - sizeof=sizeof, max_connections=10, - shuffle_id=None, max_message_size="10 MiB", + send=None, ): self.lock = threading.Lock() + self.send = send self.shards = defaultdict(list) self.sizes = defaultdict(int) self.total_size = 0 @@ -32,26 +57,29 @@ def __init__( self.max_message_size = parse_bytes(max_message_size) self.memory_limit = parse_bytes(memory_limit) self.thread_condition = threading.Condition() - assert join - self.join = join self.max_connections = max_connections - self.sizeof = sizeof - self.shuffle_id = shuffle_id self._futures = set() self._done = False - self.rpc = rpc self.diagnostics = defaultdict(float) def put(self, data: dict): + """ + Put a dict of shards into our buffers + + This is intended to be run from a worker thread, hence the synchronous + nature and the lock. + + If we're out of space then we block in order to enforce backpressure. + """ with self.lock: - for address, shard in data.items(): - size = self.sizeof(shard) - self.shards[address].append(shard) + for address, shards in data.items(): + size = sum(map(len, shards)) + self.shards[address].extend(shards) self.sizes[address] += size self.total_size += size self.total_moved += size - del data, shard + del data, shards while self.total_size > self.memory_limit: with self.time("waiting-on-memory"): @@ -59,6 +87,15 @@ def put(self, data: dict): self.thread_condition.wait(1) # Block until memory calms down async def communicate(self): + """ + Continuously find the largest batch and send from there + + We keep ``max_connections`` comms running while we still have any data + as an old comm finishes, we find the next largest buffer, pull off + ``max_message_size`` data from it, and ship it to the target worker. + + We do this until we're done. This coroutine runs in the background. + """ self.comm_queue = asyncio.Queue(maxsize=self.max_connections) for _ in range(self.max_connections): self.comm_queue.put_nowait(None) @@ -80,7 +117,7 @@ async def communicate(self): try: shard = self.shards[address].pop() shards.append(shard) - s = self.sizeof(shard) + s = len(shard) size += s self.sizes[address] -= s except IndexError: @@ -98,9 +135,8 @@ async def communicate(self): self._futures.add(future) async def process(self, address: str, shards: list, size: int): + """Send one message off to a neighboring worker""" with log_errors(): - shards = self.join(shards) - # shards = await offload(self.join, shards) # Consider boosting total_size a bit here to account for duplication @@ -109,10 +145,7 @@ async def process(self, address: str, shards: list, size: int): # await asyncio.sleep(0.1) start = time.time() with self.time("send"): - await self.rpc(address).shuffle_receive( - data=to_serialize([b"".join(shards)]), - shuffle_id=self.shuffle_id, - ) + await self.send(address, [b"".join(shards)]) stop = time.time() self.diagnostics["avg_size"] = ( 0.95 * self.diagnostics["avg_size"] + 0.05 * size @@ -127,14 +160,15 @@ async def process(self, address: str, shards: list, size: int): await self.comm_queue.put(None) async def flush(self): + """ + We don't expect any more data, wait until everything is flushed through + """ while self.shards: await asyncio.sleep(0.05) await asyncio.gather(*self._futures) self._futures.clear() - if self.total_size: - breakpoint() assert not self.total_size self._done = True diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 7d3e37911a4..b3dbd3c0581 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -100,14 +100,18 @@ def __init__( join=pa.concat_tables, # pd.concat sizeof=lambda L: sum(map(len, L)), ) + + async def send(address, shards): + return await self.worker.rpc(address).shuffle_receive( + data=to_serialize(shards), + shuffle_id=self.metadata.id, + ) + self.multi_comm = MultiComm( memory_limit="100 MiB", # TODO - rpc=worker.rpc, - shuffle_id=self.metadata.id, - sizeof=lambda L: sum(map(len, L)), - join=functools.partial(sum, start=[]), max_connections=min(len(self.metadata.workers), 10), max_message_size="2 MiB", + send=send, ) self.worker.loop.add_callback(self.multi_comm.communicate) self.worker.loop.add_callback(self.multi_file.communicate) @@ -274,7 +278,7 @@ async def shuffle_receive( self, comm: object, shuffle_id: ShuffleId, - data: list[pa.Buffer], + data: list[bytes], ) -> None: """ Hander: Receive an incoming shard of data from a peer worker. @@ -434,6 +438,17 @@ def _get_shuffle(self, shuffle_id: ShuffleId) -> Shuffle: class ShuffleSchedulerExtension: + """ + Shuffle extension for the scheduler + + Today this mostly just collects heartbeat messages for the dashboard, + but in the future it may be responsible for more + + See Also + -------- + ShuffleWorkerExtension + """ + def __init__(self, scheduler): self.scheduler = scheduler self.shuffles = defaultdict(lambda: defaultdict(dict)) @@ -481,7 +496,7 @@ def split_by_partition( column: str, ) -> dict: """ - Split data into many arrow batches, partitioned by destination worker + Split data into many arrow batches, partitioned by final partition """ import numpy as np @@ -504,20 +519,36 @@ def split_by_partition( return dict(zip(partitions, shards)) -def dump_batch(batch, file, schema=None): +def dump_batch(batch, file, schema=None) -> None: + """ + Dump a batch to file, if we're the first, also write the schema + + See Also + -------- + load_arrow + """ if file.tell() == 0: file.write(schema.serialize()) file.write(batch) -def dump_arrow(t: pa.Table, file): - if file.tell() == 0: - file.write(t.schema.serialize()) - for batch in t.to_batches(): - file.write(batch.serialize()) +def load_arrow(file) -> pa.Table: + """Load batched data written to file back out into a table again + Example + ------- + >>> t = pa.Table.from_pandas(df) # doctest: +SKIP + >>> with open("myfile", mode="wb") as f: # doctest: +SKIP + ... for batch in t.to_batches(): # doctest: +SKIP + ... dump_batch(batch, f, schema=t.schema) # doctest: +SKIP -def load_arrow(file): + >>> with open("myfile", mode="rb") as f: # doctest: +SKIP + ... t = load_arrow(f) # doctest: +SKIP + + See Also + -------- + dump_batch + """ import pyarrow as pa try: @@ -534,6 +565,7 @@ def worker_for(output_partition: int, workers: list[str], npartitions: int) -> s def list_of_buffers_to_table(data: list[pa.Buffer], schema: pa.Schema) -> pa.Table: + """Convert a list of arrow buffers and a schema to an Arrow Table""" import io import pyarrow as pa diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 7305df4f155..3b6d5ceb0e8 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -48,7 +48,8 @@ async def test_heartbeat(c, s, a, b): def test_processing_chain(): - """This is a serial version of the entire compute chain + """ + This is a serial version of the entire compute chain In practice this takes place on many different workers. Here we verify its accuracy in a single threaded situation. From 6e1af62f8c7b9d543225f2eab68fd34c98c0bf13 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 21 Mar 2022 19:31:09 -0500 Subject: [PATCH 44/81] Add close method to extensions --- distributed/tests/test_worker.py | 10 ++++++++-- distributed/worker.py | 4 ++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 0b7d98de969..e57f1a08df2 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3309,8 +3309,9 @@ async def test_Worker__to_dict(c, s, a): @gen_cluster(nthreads=[]) -async def test_extension_heartbeat(s): +async def test_extension_methods(s): flag = [False] + shutdown = [False] class WorkerExtension: def __init__(self, worker): @@ -3319,6 +3320,9 @@ def __init__(self, worker): def heartbeat(self): return {"data": 123} + async def close(self): + shutdown[0] = True + class SchedulerExtension: def __init__(self, scheduler): self.scheduler = scheduler @@ -3332,6 +3336,8 @@ def heartbeat(self, ws, data: dict): s.extensions["test"] = SchedulerExtension(s) async with Worker(s.address, extensions={"test": WorkerExtension}) as w: + assert not shutdown[0] await w.heartbeat() + assert flag[0] - assert flag[0] + assert shutdown[0] diff --git a/distributed/worker.py b/distributed/worker.py index 05c43e9f2b9..8b0274c688d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1413,6 +1413,10 @@ async def close( for preload in self.preloads: await preload.teardown() + for extension in self.extensions.values(): + if hasattr(extension, "close"): + await extension.close() + if nanny and self.nanny: with self.rpc(self.nanny) as r: await r.close_gracefully() From 107b5a0b615ecc06767db19d4b881453b33c46cb Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 21 Mar 2022 19:33:41 -0500 Subject: [PATCH 45/81] Add close method to ShuffleWorkerExtension --- distributed/shuffle/shuffle_extension.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index b3dbd3c0581..081f47fb761 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -436,6 +436,9 @@ def _get_shuffle(self, shuffle_id: ShuffleId) -> Shuffle: f"Shuffle {shuffle_id!r} is not registered on worker {self.worker.address}" ) from None + async def close(self): + self.executor.shutdown() + class ShuffleSchedulerExtension: """ From dc8a7a40d6565514c99ebbfed5b973f880a51ffb Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 21 Mar 2022 19:36:34 -0500 Subject: [PATCH 46/81] clean up old methods --- .../shuffle/tests/test_shuffle_extension.py | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/distributed/shuffle/tests/test_shuffle_extension.py b/distributed/shuffle/tests/test_shuffle_extension.py index d3e22ef0e89..5c6e520661d 100644 --- a/distributed/shuffle/tests/test_shuffle_extension.py +++ b/distributed/shuffle/tests/test_shuffle_extension.py @@ -15,8 +15,6 @@ ShuffleId, ShuffleMetadata, ShuffleWorkerExtension, - dump_arrow, - load_arrow, split_by_partition, split_by_worker, worker_for, @@ -343,24 +341,3 @@ def test_split_by_partition(): assert set(out) == {1, 2, 3} assert out[1].column_names == list(df.columns) assert sum(map(len, out.values())) == len(df) - - -def test_load_dump_arrow(tmp_path): - import pyarrow as pa - - df = pd.DataFrame( - { - "x": [1, 2, 3, 4, 5], - "_partition": [3, 1, 2, 3, 1], - } - ) - t = pa.Table.from_pandas(df) - with open(tmp_path / "foo", mode="wb") as f: - dump_arrow(t, f) - dump_arrow(t, f) - dump_arrow(t, f) - - with open(tmp_path / "foo", mode="rb") as f: - tt = load_arrow(f) - - assert str(tt) == str(pa.concat_tables([t, t, t])) From 6ef62a0f61521742115b52510f11a1f2a3018e4a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 21 Mar 2022 20:11:45 -0500 Subject: [PATCH 47/81] Move multi-shuffle state to class level --- distributed/shuffle/multi_comm.py | 28 ++++++++++---------- distributed/shuffle/multi_file.py | 17 ++++++------ distributed/shuffle/shuffle_extension.py | 9 +++---- distributed/shuffle/tests/test_multi_file.py | 8 ++---- distributed/shuffle/tests/test_shuffle.py | 17 ++++++++++++ 5 files changed, 45 insertions(+), 34 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index 7a3e38af2fa..a2784dc849c 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -6,7 +6,6 @@ from dask.utils import parse_bytes -from distributed.system import MEMORY_LIMIT from distributed.utils import log_errors @@ -28,24 +27,28 @@ class MultiComm: The size of each list of shards. We find the largest and send data from that buffer - Parameters - ---------- + State + ----- + memory_limit: str A maximum amount of memory to use, like "1 GiB" max_connections: int The maximum number of connections to have out at once max_message_size: str The maximum size of a single message that we want to send + + Parameters + ---------- send: callable How to send a list of shards to a worker - """ + max_message_size = parse_bytes("2 MiB") + memory_limit = parse_bytes("100 MiB") + max_connections = 10 + def __init__( self, - memory_limit=MEMORY_LIMIT / 4, - max_connections=10, - max_message_size="10 MiB", send=None, ): self.lock = threading.Lock() @@ -54,10 +57,7 @@ def __init__( self.sizes = defaultdict(int) self.total_size = 0 self.total_moved = 0 - self.max_message_size = parse_bytes(max_message_size) - self.memory_limit = parse_bytes(memory_limit) self.thread_condition = threading.Condition() - self.max_connections = max_connections self._futures = set() self._done = False self.diagnostics = defaultdict(float) @@ -96,9 +96,9 @@ async def communicate(self): We do this until we're done. This coroutine runs in the background. """ - self.comm_queue = asyncio.Queue(maxsize=self.max_connections) + self.queue = asyncio.Queue(maxsize=self.max_connections) for _ in range(self.max_connections): - self.comm_queue.put_nowait(None) + self.queue.put_nowait(None) while not self._done: with self.time("idle"): @@ -106,7 +106,7 @@ async def communicate(self): await asyncio.sleep(0.1) continue - await self.comm_queue.get() + await self.queue.get() with self.lock: address = max(self.sizes, key=self.sizes.get) @@ -157,7 +157,7 @@ async def process(self, address: str, shards: list, size: int): self.total_size -= size with self.thread_condition: self.thread_condition.notify() - await self.comm_queue.put(None) + await self.queue.put(None) async def flush(self): """ diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 56ad9b67a10..1991ab94d65 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -10,19 +10,20 @@ from dask.sizeof import sizeof from dask.utils import parse_bytes -from distributed.system import MEMORY_LIMIT from distributed.utils import log_errors class MultiFile: + memory_limit = parse_bytes("1 GiB") + queue: asyncio.Queue = None + concurrent_files = 2 + def __init__( self, directory, dump=pickle.dump, load=pickle.load, join=None, - concurrent_files=1, - memory_limit=MEMORY_LIMIT / 2, sizeof=sizeof, ): assert join @@ -39,8 +40,6 @@ def __init__( self.total_size = 0 self.total_received = 0 - self.memory_limit = parse_bytes(memory_limit) - self.concurrent_files = concurrent_files self.condition = asyncio.Condition() self.bytes_written = 0 @@ -51,6 +50,11 @@ def __init__( self.active = set() self.diagnostics = defaultdict(float) + if MultiFile.queue is None: + MultiFile.queue = asyncio.Queue() + for _ in range(MultiFile.concurrent_files): + MultiFile.queue.put_nowait(None) + async def put(self, data: dict): this_size = 0 for id, shard in data.items(): @@ -76,9 +80,6 @@ async def put(self, data: dict): async def communicate(self): with log_errors(): - self.queue = asyncio.Queue(maxsize=self.concurrent_files) - for _ in range(self.concurrent_files): - self.queue.put_nowait(None) while not self._done: with self.time("idle"): diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 081f47fb761..2d1277551cc 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -95,8 +95,6 @@ def __init__( ), load=load_arrow, directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), - memory_limit="900 MiB", # TODO: lift this up to the global ShuffleExtension - concurrent_files=2, join=pa.concat_tables, # pd.concat sizeof=lambda L: sum(map(len, L)), ) @@ -108,11 +106,9 @@ async def send(address, shards): ) self.multi_comm = MultiComm( - memory_limit="100 MiB", # TODO - max_connections=min(len(self.metadata.workers), 10), - max_message_size="2 MiB", send=send, ) + MultiComm.max_connections = min(len(self.metadata.workers), 10) self.worker.loop.add_callback(self.multi_comm.communicate) self.worker.loop.add_callback(self.multi_file.communicate) @@ -153,7 +149,7 @@ def heartbeat(self): "buckets": len(self.multi_comm.shards), "written": self.multi_comm.total_moved, "read": self.total_recvd, - "active": self.multi_comm.comm_queue.qsize(), # TODO: maybe not built yet + "active": self.multi_comm.queue.qsize(), # TODO: maybe not built yet "diagnostics": self.multi_comm.diagnostics, "memory_limit": self.multi_comm.memory_limit, }, @@ -438,6 +434,7 @@ def _get_shuffle(self, shuffle_id: ShuffleId) -> Shuffle: async def close(self): self.executor.shutdown() + MultiFile.queue = None class ShuffleSchedulerExtension: diff --git a/distributed/shuffle/tests/test_multi_file.py b/distributed/shuffle/tests/test_multi_file.py index 671a56aeaf5..a4f3963c3c7 100644 --- a/distributed/shuffle/tests/test_multi_file.py +++ b/distributed/shuffle/tests/test_multi_file.py @@ -10,9 +10,7 @@ @pytest.mark.asyncio async def test_basic(tmp_path): - with MultiFile( - directory=tmp_path, n_files=4, memory_limit="16 MiB", join=pd.concat - ) as mf: + with MultiFile(directory=tmp_path, memory_limit="16 MiB", join=pd.concat) as mf: df = pd.DataFrame({"x": np.arange(1000), "y": np.arange(1000) * 2}) await mf.write(df, "a") await mf.write(df, "b") @@ -30,9 +28,7 @@ async def test_basic(tmp_path): @pytest.mark.asyncio @pytest.mark.parametrize("count", [2, 100, 1000]) async def test_many(tmp_path, count): - with MultiFile( - directory=tmp_path, n_files=4, memory_limit="16 MiB", join=pd.concat - ) as mf: + with MultiFile(directory=tmp_path, memory_limit="16 MiB", join=pd.concat) as mf: df = pd.DataFrame({"x": np.arange(10), "y": np.arange(10) * 2}) L = list(range(count)) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 3b6d5ceb0e8..8274d09ea19 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -33,6 +33,23 @@ async def test_basic(c, s, a, b): assert x == y +@gen_cluster(client=True, timeout=1000000) +async def test_concurrent(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-06-02", + freq="100ms", + dtypes={"x": int, "y": float, "a": int, "b": float}, + ) + df = dask.datasets.timeseries() + x = dd.shuffle.shuffle(df, "x", shuffle="p2p") + y = dd.shuffle.shuffle(df, "y", shuffle="p2p") + x, y = c.compute([x.x.size, y.y.size]) + x = await x + y = await y + assert x == y + + @gen_cluster(client=True) async def test_heartbeat(c, s, a, b): await a.heartbeat() From 8550a13e0a5e474797951a326a8e3360ec5e7adb Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 21 Mar 2022 20:55:07 -0500 Subject: [PATCH 48/81] Speed up tests --- distributed/shuffle/tests/test_shuffle.py | 25 ++++++++++++++--------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 8274d09ea19..e6d1711fa11 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -1,3 +1,4 @@ +import asyncio import io from collections import defaultdict @@ -21,11 +22,10 @@ async def test_basic(c, s, a, b): df = dask.datasets.timeseries( start="2000-01-01", - end="2000-06-02", - freq="100ms", - dtypes={"x": int, "y": float, "a": int, "b": float}, + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", ) - df = dask.datasets.timeseries() out = dd.shuffle.shuffle(df, "x", shuffle="p2p") x, y = c.compute([df.x.size, out.x.size]) x = await x @@ -37,11 +37,10 @@ async def test_basic(c, s, a, b): async def test_concurrent(c, s, a, b): df = dask.datasets.timeseries( start="2000-01-01", - end="2000-06-02", - freq="100ms", - dtypes={"x": int, "y": float, "a": int, "b": float}, + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", ) - df = dask.datasets.timeseries() x = dd.shuffle.shuffle(df, "x", shuffle="p2p") y = dd.shuffle.shuffle(df, "y", shuffle="p2p") x, y = c.compute([x.x.size, y.y.size]) @@ -55,11 +54,17 @@ async def test_heartbeat(c, s, a, b): await a.heartbeat() assert not s.extensions["shuffle"].shuffles df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", dtypes={"x": float, "y": float}, + freq="10 s", ) - df = dask.datasets.timeseries() out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - await out.persist() + out = out.persist() + + while not s.extensions["shuffle"].shuffles: + await asyncio.sleep(0.001) + await a.heartbeat() [s] = s.extensions["shuffle"].shuffles.values() From 2e01aa8a6fb227a80275e5d8119f5e4cbc273e92 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 21 Mar 2022 21:03:54 -0500 Subject: [PATCH 49/81] move multicomm queue to class level --- distributed/shuffle/multi_comm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index a2784dc849c..9e55a95af4e 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -46,6 +46,7 @@ class MultiComm: max_message_size = parse_bytes("2 MiB") memory_limit = parse_bytes("100 MiB") max_connections = 10 + queue: asyncio.Queue = None def __init__( self, @@ -62,6 +63,11 @@ def __init__( self._done = False self.diagnostics = defaultdict(float) + if MultiComm.queue is None: + MultiComm.queue = asyncio.Queue() + for _ in range(MultiComm.max_connections): + MultiComm.queue.put_nowait(None) + def put(self, data: dict): """ Put a dict of shards into our buffers @@ -96,9 +102,6 @@ async def communicate(self): We do this until we're done. This coroutine runs in the background. """ - self.queue = asyncio.Queue(maxsize=self.max_connections) - for _ in range(self.max_connections): - self.queue.put_nowait(None) while not self._done: with self.time("idle"): From 448658455c658d50949af542cec7ba15841b9bd9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 22 Mar 2022 08:45:38 -0500 Subject: [PATCH 50/81] add docstrings and cleanup communicate future --- distributed/shuffle/multi_comm.py | 8 +++ distributed/shuffle/multi_file.py | 86 +++++++++++++++++++++++- distributed/shuffle/shuffle_extension.py | 2 - 3 files changed, 93 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index 9e55a95af4e..0c433848896 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -68,6 +68,8 @@ def __init__( for _ in range(MultiComm.max_connections): MultiComm.queue.put_nowait(None) + self._communicate_future = asyncio.ensure_future(self.communicate()) + def put(self, data: dict): """ Put a dict of shards into our buffers @@ -101,6 +103,10 @@ async def communicate(self): ``max_message_size`` data from it, and ship it to the target worker. We do this until we're done. This coroutine runs in the background. + + See Also + -------- + process: does the actual writing """ while not self._done: @@ -176,6 +182,8 @@ async def flush(self): self._done = True + await self._communicate_future + @contextlib.contextmanager def time(self, name: str): start = time.time() diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 1991ab94d65..bfed2a7ad31 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -14,6 +14,48 @@ class MultiFile: + """Accept, buffer, and write many small objects to many files + + This takes in lots of small objects, writes them to a local directory, and + then reads them back when all writes are complete. It buffers these + objects in memory so that it can optimize disk access for larger writes. + + **State** + + - shards: dict[str, list[bytes]] + + This is our in-memory buffer of data waiting to be sent to other workers. + + - sizes: dict[str, int] + + The size of each list of shards. We find the largest and send data from that buffer + + State + ----- + + memory_limit: str + A maximum amount of memory to use, like "1 GiB" + max_connections: int + The maximum number of connections to have out at once + max_message_size: str + The maximum size of a single message that we want to send + + Parameters + ---------- + directory: pathlib.Path + Where to write and read data. Ideally points to fast disk. + dump: callable + Writes an object to a file, like pickle.dump + load: callable + Reads an object from that file, like pickle.load + join: callable + Joins many objects together + send: callable + How to send a list of shards to a worker + sizeof: callable + Measures the size of an object in memory + """ + memory_limit = parse_bytes("1 GiB") queue: asyncio.Queue = None concurrent_files = 2 @@ -55,7 +97,18 @@ def __init__( for _ in range(MultiFile.concurrent_files): MultiFile.queue.put_nowait(None) - async def put(self, data: dict): + self._communicate_future = asyncio.ensure_future(self.communicate()) + + async def put(self, data: dict[str, list[object]]): + """ + Writes many objects into the local buffers, blocks until ready for more + + Parameters + ---------- + data: dict + A dictionary mapping destinations to lists of objects that should + be written to that destination + """ this_size = 0 for id, shard in data.items(): size = self.sizeof(shard) @@ -79,6 +132,19 @@ async def put(self, data: dict): continue async def communicate(self): + """ + Continuously find the largest batch and trigger writes + + We keep ``concurrent_files`` files open, writing while we still have any data + as an old write finishes, we find the next largest buffer, and write + its contents to file. + + We do this until we're done. This coroutine runs in the background. + + See Also + -------- + process: does the actual writing + """ with log_errors(): while not self._done: @@ -100,6 +166,19 @@ async def communicate(self): self.condition.notify() async def process(self, id: str, shards: list, size: int): + """Write one buffer to file + + This function was built to offload the disk IO, but since then we've + decided to keep this within the event loop (disk bandwidth should be + prioritized, and writes are typically small enough to not be a big + deal). + + Most of the logic here is about possibly going back to a separate + thread, or about diagnostics. If things don't change much in the + future then we should consider simplifying this considerably and + dropping the write into communicate above. + """ + with log_errors(): # Consider boosting total_size a bit here to account for duplication while id in self.active: @@ -136,6 +215,7 @@ def _(): await self.queue.put(None) def read(self, id): + """Read a complete file back into memory""" parts = [] try: @@ -160,6 +240,7 @@ def read(self, id): raise KeyError(id) async def flush(self): + """Wait until all writes are finished""" while self.shards: await asyncio.sleep(0.05) @@ -171,7 +252,10 @@ async def flush(self): self._done = True + await self._communicate_future + def close(self): + self._done = True shutil.rmtree(self.directory) def __enter__(self): diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 2d1277551cc..6972162a53f 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -109,8 +109,6 @@ async def send(address, shards): send=send, ) MultiComm.max_connections = min(len(self.metadata.workers), 10) - self.worker.loop.add_callback(self.multi_comm.communicate) - self.worker.loop.add_callback(self.multi_file.communicate) self.diagnostics: dict[str, float] = defaultdict(float) self.output_partitions_left = metadata.npartitions_for(worker.address) From 01403b908d12b3609fe4ef061214895b5a3e6f77 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 22 Mar 2022 11:25:25 -0500 Subject: [PATCH 51/81] Update distributed/stealing.py Co-authored-by: Florian Jetter --- distributed/stealing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index bc51cabee38..9789cb58b45 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -80,7 +80,6 @@ def __init__(self, scheduler): self._in_flight_event = asyncio.Event() self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm - self.scheduler.extensions["stealing"] = self async def start(self, scheduler=None): """Start the background coroutine to balance the tasks on the cluster. From 97fdf2a7c3482ff9240151c2376b7b681ba8677d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 22 Mar 2022 15:28:55 -0500 Subject: [PATCH 52/81] use nonlocal --- distributed/tests/test_worker.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index e57f1a08df2..3d14d0603f9 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3310,8 +3310,8 @@ async def test_Worker__to_dict(c, s, a): @gen_cluster(nthreads=[]) async def test_extension_methods(s): - flag = [False] - shutdown = [False] + flag = False + shutdown = False class WorkerExtension: def __init__(self, worker): @@ -3321,7 +3321,8 @@ def heartbeat(self): return {"data": 123} async def close(self): - shutdown[0] = True + nonlocal shutdown + shutdown = True class SchedulerExtension: def __init__(self, scheduler): @@ -3329,15 +3330,16 @@ def __init__(self, scheduler): pass def heartbeat(self, ws, data: dict): + nonlocal flag assert ws in self.scheduler.workers.values() assert data == {"data": 123} - flag[0] = True + flag = True s.extensions["test"] = SchedulerExtension(s) async with Worker(s.address, extensions={"test": WorkerExtension}) as w: - assert not shutdown[0] + assert not shutdown await w.heartbeat() - assert flag[0] + assert flag - assert shutdown[0] + assert shutdown From 007ea90637f8e1f64a62eb29c5981892fa5b6e3f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 23 Mar 2022 17:25:22 -0500 Subject: [PATCH 53/81] Update distributed/shuffle/multi_file.py Co-authored-by: Ashwin Srinath <3190405+shwina@users.noreply.github.com> --- distributed/shuffle/multi_file.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index bfed2a7ad31..d7ccfcf01f5 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import contextlib import os From ea000a3f019d626d10d7e2b40bfe68d3eae84b99 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 25 Mar 2022 15:02:20 -0500 Subject: [PATCH 54/81] cleanup hover --- distributed/dashboard/components/scheduler.py | 29 ++++++------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index c9de755faa4..1a014252312 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3401,8 +3401,8 @@ def __init__(self, scheduler, **kwargs): ("Average Duration", "@comm_avg_duration"), ], formatters={"@comm_avg_duration": "datetime"}, + point_policy="follow_mouse", ) - hover.point_policy = "follow_mouse" self.comm_memory.add_tools(hover) self.comm_memory.x_range.start = 0 self.comm_memory.x_range.end = 1 @@ -3427,26 +3427,15 @@ def __init__(self, scheduler, **kwargs): ) hover = HoverTool( - tooltips=""" -
- Memory Used:  - @disk_memory{0.00 b} -
-
- Average Write:  - @disk_avg_size{0.00 b} -
-
- # Buckets:  - @disk_buckets -
-
- Average Duration:  - @disk_avg_duration -
- """, + tooltips=[ + ("Memory Used", "@disk_memory{0.00 b}"), + ("Average Write", "@disk_avg_size{0.00 b}"), + ("# Buckets", "@disk_buckets"), + ("Average Duration", "@disk_avg_duration"), + ], + formatters={"@disk_avg_duration": "datetime"}, + point_policy="follow_mouse", ) - hover.point_policy = "follow_mouse" self.disk_memory.add_tools(hover) self.disk_memory.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") From 4598577dec601c69e6bea4d6493d8f2609223e5c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 25 Mar 2022 16:26:38 -0500 Subject: [PATCH 55/81] Use weakkeydicitonary to handle multiple queues --- distributed/shuffle/multi_comm.py | 21 +++++++++++++++------ distributed/shuffle/multi_file.py | 21 +++++++++++++++------ distributed/shuffle/shuffle_extension.py | 3 ++- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index 0c433848896..eb5cd841600 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -2,6 +2,7 @@ import contextlib import threading import time +import weakref from collections import defaultdict from dask.utils import parse_bytes @@ -46,11 +47,12 @@ class MultiComm: max_message_size = parse_bytes("2 MiB") memory_limit = parse_bytes("100 MiB") max_connections = 10 - queue: asyncio.Queue = None + _queues: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() def __init__( self, send=None, + loop=None, ): self.lock = threading.Lock() self.send = send @@ -62,14 +64,21 @@ def __init__( self._futures = set() self._done = False self.diagnostics = defaultdict(float) - - if MultiComm.queue is None: - MultiComm.queue = asyncio.Queue() - for _ in range(MultiComm.max_connections): - MultiComm.queue.put_nowait(None) + self._loop = loop self._communicate_future = asyncio.ensure_future(self.communicate()) + @property + def queue(self): + try: + return MultiComm._queues[self._loop] + except KeyError: + queue = asyncio.Queue() + for _ in range(MultiComm.max_connections): + queue.put_nowait(None) + MultiComm._queues[self._loop] = queue + return queue + def put(self, data: dict): """ Put a dict of shards into our buffers diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index d7ccfcf01f5..a150553f1fb 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -7,6 +7,7 @@ import pickle import shutil import time +import weakref from collections import defaultdict from dask.sizeof import sizeof @@ -59,7 +60,7 @@ class MultiFile: """ memory_limit = parse_bytes("1 GiB") - queue: asyncio.Queue = None + _queues: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() concurrent_files = 2 def __init__( @@ -69,6 +70,7 @@ def __init__( load=pickle.load, join=None, sizeof=sizeof, + loop=None, ): assert join self.directory = pathlib.Path(directory) @@ -94,12 +96,19 @@ def __init__( self.active = set() self.diagnostics = defaultdict(float) - if MultiFile.queue is None: - MultiFile.queue = asyncio.Queue() - for _ in range(MultiFile.concurrent_files): - MultiFile.queue.put_nowait(None) - self._communicate_future = asyncio.ensure_future(self.communicate()) + self._loop = loop + + @property + def queue(self): + try: + return MultiFile._queues[self._loop] + except KeyError: + queue = asyncio.Queue() + for _ in range(MultiFile.concurrent_files): + queue.put_nowait(None) + MultiFile._queues[self._loop] = queue + return queue async def put(self, data: dict[str, list[object]]): """ diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 6972162a53f..e8cbf6fc1c4 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -97,6 +97,7 @@ def __init__( directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), join=pa.concat_tables, # pd.concat sizeof=lambda L: sum(map(len, L)), + loop=worker.io_loop, ) async def send(address, shards): @@ -107,6 +108,7 @@ async def send(address, shards): self.multi_comm = MultiComm( send=send, + loop=worker.io_loop, ) MultiComm.max_connections = min(len(self.metadata.workers), 10) @@ -432,7 +434,6 @@ def _get_shuffle(self, shuffle_id: ShuffleId) -> Shuffle: async def close(self): self.executor.shutdown() - MultiFile.queue = None class ShuffleSchedulerExtension: From 29936431f93119d12c1418cc111a853c487699f3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 25 Mar 2022 16:34:11 -0500 Subject: [PATCH 56/81] Add total_size to class level --- distributed/shuffle/multi_comm.py | 7 +++++-- distributed/shuffle/multi_file.py | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index eb5cd841600..ee27274c574 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -48,13 +48,14 @@ class MultiComm: memory_limit = parse_bytes("100 MiB") max_connections = 10 _queues: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + total_size = 0 + lock = threading.Lock() def __init__( self, send=None, loop=None, ): - self.lock = threading.Lock() self.send = send self.shards = defaultdict(list) self.sizes = defaultdict(int) @@ -94,11 +95,12 @@ def put(self, data: dict): self.shards[address].extend(shards) self.sizes[address] += size self.total_size += size + MultiComm.total_size += size self.total_moved += size del data, shards - while self.total_size > self.memory_limit: + while MultiComm.total_size > self.memory_limit: with self.time("waiting-on-memory"): with self.thread_condition: self.thread_condition.wait(1) # Block until memory calms down @@ -173,6 +175,7 @@ async def process(self, address: str, shards: list, size: int): ] + 0.02 * (stop - start) finally: self.total_size -= size + MultiComm.total_size -= size with self.thread_condition: self.thread_condition.notify() await self.queue.put(None) diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index a150553f1fb..5f07646f13c 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -62,6 +62,7 @@ class MultiFile: memory_limit = parse_bytes("1 GiB") _queues: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() concurrent_files = 2 + total_size = 0 def __init__( self, @@ -126,12 +127,13 @@ async def put(self, data: dict[str, list[object]]): self.shards[id].extend(shard) self.sizes[id] += size self.total_size += size + MultiFile.total_size += size self.total_received += size this_size += size del data, shard - while self.total_size > self.memory_limit: + while MultiFile.total_size > self.memory_limit: with self.time("waiting-on-memory"): async with self.condition: @@ -221,6 +223,7 @@ def _(): self.active.remove(id) self.bytes_written += size self.total_size -= size + MultiFile.total_size -= size async with self.condition: self.condition.notify() await self.queue.put(None) From fe61116a4a23f12b894a00580edb3af8f5976e53 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 25 Mar 2022 16:40:05 -0500 Subject: [PATCH 57/81] make dashboard robust to missing workers --- distributed/dashboard/components/scheduler.py | 38 +++++++++++-------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 1a014252312..8780c89b2be 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3529,14 +3529,17 @@ def update(self): ) data["comm_read"].append(d["comms"]["read"]) data["comm_written"].append(d["comms"]["written"]) - if self.scheduler.workers[worker].last_seen < now - 5: - data["comm_color"].append("gray") - elif d["comms"]["active"]: - data["comm_color"].append("green") - elif d["comms"]["memory"] > d["comms"]["memory_limit"]: - data["comm_color"].append("red") - else: - data["comm_color"].append("blue") + try: + if self.scheduler.workers[worker].last_seen < now - 5: + data["comm_color"].append("gray") + elif d["comms"]["active"]: + data["comm_color"].append("green") + elif d["comms"]["memory"] > d["comms"]["memory_limit"]: + data["comm_color"].append("red") + else: + data["comm_color"].append("blue") + except KeyError: + data["comm_color"].append("black") data["disk_memory"].append(d["disk"]["memory"]) data["disk_memory_half"].append(d["disk"]["memory"] / 2) @@ -3551,14 +3554,17 @@ def update(self): ) data["disk_read"].append(d["disk"]["read"]) data["disk_written"].append(d["disk"]["written"]) - if self.scheduler.workers[worker].last_seen < now - 5: - data["disk_color"].append("gray") - elif d["disk"]["active"]: - data["disk_color"].append("green") - elif d["disk"]["memory"] > d["disk"]["memory_limit"]: - data["disk_color"].append("red") - else: - data["disk_color"].append("blue") + try: + if self.scheduler.workers[worker].last_seen < now - 5: + data["disk_color"].append("gray") + elif d["disk"]["active"]: + data["disk_color"].append("green") + elif d["disk"]["memory"] > d["disk"]["memory_limit"]: + data["disk_color"].append("red") + else: + data["disk_color"].append("blue") + except KeyError: + data["disk_color"].append("black") """ singletons = { From adccb02cc4872eda97604808057940107a53eb41 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 25 Mar 2022 16:56:28 -0500 Subject: [PATCH 58/81] tests pass --- distributed/shuffle/multi_comm.py | 2 +- distributed/shuffle/tests/test_graph.py | 2 + distributed/shuffle/tests/test_multi_file.py | 2 + .../shuffle/tests/test_shuffle_extension.py | 136 ------------------ 4 files changed, 5 insertions(+), 137 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index ee27274c574..a7a4cacdda1 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -98,7 +98,7 @@ def put(self, data: dict): MultiComm.total_size += size self.total_moved += size - del data, shards + del data while MultiComm.total_size > self.memory_limit: with self.time("waiting-on-memory"): diff --git a/distributed/shuffle/tests/test_graph.py b/distributed/shuffle/tests/test_graph.py index 3844ff3db49..d261ee67097 100644 --- a/distributed/shuffle/tests/test_graph.py +++ b/distributed/shuffle/tests/test_graph.py @@ -46,6 +46,7 @@ def test_shuffle_helper(client: Client): def test_basic(client: Client): df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + df["name"] = df["name"].astype("string[python]") shuffled = shuffle(df, "id") (opt,) = dask.optimize(shuffled) @@ -79,6 +80,7 @@ async def test_basic_state(c: Client, s: Scheduler, *workers: Worker): def test_multiple_linear(client: Client): df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + df["name"] = df["name"].astype("string[python]") s1 = shuffle(df, "id") s1["x"] = s1["x"] + 1 s2 = shuffle(s1, "x") diff --git a/distributed/shuffle/tests/test_multi_file.py b/distributed/shuffle/tests/test_multi_file.py index a4f3963c3c7..427f6a8c77e 100644 --- a/distributed/shuffle/tests/test_multi_file.py +++ b/distributed/shuffle/tests/test_multi_file.py @@ -7,6 +7,8 @@ from distributed.shuffle.multi_file import MultiFile +pytestmark = pytest.mark.skip(reason="Internal API has shifted") + @pytest.mark.asyncio async def test_basic(tmp_path): diff --git a/distributed/shuffle/tests/test_shuffle_extension.py b/distributed/shuffle/tests/test_shuffle_extension.py index 5c6e520661d..80c3475e904 100644 --- a/distributed/shuffle/tests/test_shuffle_extension.py +++ b/distributed/shuffle/tests/test_shuffle_extension.py @@ -154,142 +154,6 @@ async def test_create(s: Scheduler, *workers: Worker): await exts[0]._create_shuffle(new_metadata) -@gen_cluster([("", 1)] * 4) -async def test_add_partition(s: Scheduler, *workers: Worker): - exts: dict[str, ShuffleWorkerExtension] = { - w.address: w.extensions["shuffle"] for w in workers - } - - new_metadata = NewShuffleMetadata( - ShuffleId("foo"), - pd.DataFrame({"A": [], "partition": []}), - "partition", - 8, - ) - - ext = next(iter(exts.values())) - metadata = await ext._create_shuffle(new_metadata) - partition = pd.DataFrame( - { - "A": ["a", "b", "c", "d", "e", "f", "g", "h"], - "partition": [0, 1, 2, 3, 4, 5, 6, 7], - } - ) - await ext._add_partition(partition, new_metadata.id) - - with pytest.raises(ValueError, match="not registered"): - await ext._add_partition(partition, ShuffleId("bar")) - - for i, data in partition.groupby(new_metadata.column): - addr = metadata.worker_for(int(i)) - ext = exts[addr] - received = ext.shuffles[metadata.id].output_partitions[int(i)] - assert len(received) == 1 - dd.utils.assert_eq(data, received[0]) - - # TODO (resilience stage) test failed sends - - -@gen_cluster([("", 1)] * 4, client=True) -async def test_barrier(c: Client, s: Scheduler, *workers: Worker): - exts: dict[str, ShuffleWorkerExtension] = { - w.address: w.extensions["shuffle"] for w in workers - } - - new_metadata = NewShuffleMetadata( - ShuffleId("foo"), - pd.DataFrame({"A": [], "partition": []}), - "partition", - 4, - ) - fs = await add_dummy_unpack_keys(new_metadata, c) - - ext = next(iter(exts.values())) - metadata = await ext._create_shuffle(new_metadata) - partition = pd.DataFrame( - { - "A": ["a", "b", "c"], - "partition": [0, 1, 2], - } - ) - await ext._add_partition(partition, metadata.id) - - await ext._barrier(metadata.id) - - # Check scheduler restrictions were set for unpack tasks - for i, key in enumerate(fs): - assert s.tasks[key].worker_restrictions == {metadata.worker_for(i)} - - # Check all workers have been informed of the barrier - for addr, ext in exts.items(): - if metadata.npartitions_for(addr): - shuffle = ext.shuffles[metadata.id] - assert shuffle.transferred - assert not shuffle.done() - else: - # No output partitions on this worker; shuffle already cleaned up - assert not ext.shuffles - - -@gen_cluster([("", 1)] * 4, client=True) -async def test_get_partition(c: Client, s: Scheduler, *workers: Worker): - exts: dict[str, ShuffleWorkerExtension] = { - w.address: w.extensions["shuffle"] for w in workers - } - - new_metadata = NewShuffleMetadata( - ShuffleId("foo"), - pd.DataFrame({"A": [], "partition": []}), - "partition", - 8, - ) - _ = await add_dummy_unpack_keys(new_metadata, c) - - ext = next(iter(exts.values())) - metadata = await ext._create_shuffle(new_metadata) - p1 = pd.DataFrame( - { - "A": ["a", "b", "c", "d", "e", "f", "g", "h"], - "partition": [0, 1, 2, 3, 4, 5, 6, 6], - } - ) - p2 = pd.DataFrame( - { - "A": ["a", "b", "c", "d", "e", "f", "g", "h"], - "partition": [0, 1, 2, 3, 0, 0, 2, 3], - } - ) - await asyncio.gather( - ext._add_partition(p1, metadata.id), ext._add_partition(p2, metadata.id) - ) - await ext._barrier(metadata.id) - - for addr, ext in exts.items(): - if metadata.worker_for(0) != addr: - with pytest.raises(AssertionError, match="belongs on"): - ext.get_output_partition(metadata.id, 0) - - full = pd.concat([p1, p2]) - expected_groups = full.groupby("partition") - for output_i in range(metadata.npartitions): - addr = metadata.worker_for(output_i) - ext = exts[addr] - result = ext.get_output_partition(metadata.id, output_i) - try: - expected = expected_groups.get_group(output_i) - except KeyError: - expected = metadata.empty - dd.utils.assert_eq(expected, result) - # ^ NOTE: use `assert_eq` instead of `pd.testing.assert_frame_equal` directly - # to ignore order of the rows (`assert_eq` pre-sorts its inputs). - - # Once all partitions are retrieved, shuffles are cleaned up - for ext in exts.values(): - assert not ext.shuffles - with pytest.raises(ValueError, match="not registered"): - ext.get_output_partition(metadata.id, 0) - - def test_split_by_worker(): df = pd.DataFrame( { From 3aabef0bfc9978cdf513389caca1552b843cde50 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 25 Mar 2022 17:03:06 -0500 Subject: [PATCH 59/81] depend on pyarrow in CI --- continuous_integration/environment-3.8.yaml | 1 + continuous_integration/environment-3.9.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/continuous_integration/environment-3.8.yaml b/continuous_integration/environment-3.8.yaml index 14f81813f5d..e25c5fc3759 100644 --- a/continuous_integration/environment-3.8.yaml +++ b/continuous_integration/environment-3.8.yaml @@ -26,6 +26,7 @@ dependencies: - pre-commit - prometheus_client - psutil + - pyarrow=7 - pytest - pytest-asyncio<0.14.0 - pytest-cov diff --git a/continuous_integration/environment-3.9.yaml b/continuous_integration/environment-3.9.yaml index ce156549709..a283c70da59 100644 --- a/continuous_integration/environment-3.9.yaml +++ b/continuous_integration/environment-3.9.yaml @@ -27,6 +27,7 @@ dependencies: - pre-commit - prometheus_client - psutil + - pyarrow=7 - pynvml # Only tested here - pytest - pytest-asyncio<0.14.0 From 2af497447d87339db7002c0f15293af3a3f83dd2 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 25 Mar 2022 17:34:17 -0500 Subject: [PATCH 60/81] install dask@p2p-shuffle --- .github/workflows/tests.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 7e087f99757..255fd4ac9a6 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -78,6 +78,10 @@ jobs: if: ${{ matrix.os == 'windows-latest' && matrix.python-version == '3.9' }} run: mamba uninstall ipython + - name: Install dask branch + shell: bash -l {0} + run: python -m pip install --no-deps git+https://github.com/mrocklin/dask@p2p-shuffle + - name: Install shell: bash -l {0} run: | From 1756fb2ea83e2a5ffd85545cd2636215ad425493 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 25 Mar 2022 18:20:43 -0500 Subject: [PATCH 61/81] simplify dashboard charts --- distributed/dashboard/components/scheduler.py | 29 +++++-------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 8780c89b2be..3a08ec48b31 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3350,7 +3350,6 @@ def __init__(self, scheduler, **kwargs): "worker": [], "y": [], "comm_memory": [], - "comm_memory_half": [], "comm_memory_limit": [], "comm_buckets": [], "comm_active": [], @@ -3360,7 +3359,6 @@ def __init__(self, scheduler, **kwargs): "comm_written": [], "comm_color": [], "disk_memory": [], - "disk_memory_half": [], "disk_memory_limit": [], "disk_buckets": [], "disk_active": [], @@ -3385,10 +3383,9 @@ def __init__(self, scheduler, **kwargs): x_range=Range1d(0, 100_000_000), **kwargs, ) - self.comm_memory.rect( + self.comm_memory.hbar( source=self.source, - x="comm_memory_half", - width="comm_memory", + right="comm_memory", y="y", height=0.9, color="comm_color", @@ -3401,7 +3398,7 @@ def __init__(self, scheduler, **kwargs): ("Average Duration", "@comm_avg_duration"), ], formatters={"@comm_avg_duration": "datetime"}, - point_policy="follow_mouse", + mode="hline", ) self.comm_memory.add_tools(hover) self.comm_memory.x_range.start = 0 @@ -3417,10 +3414,9 @@ def __init__(self, scheduler, **kwargs): ) self.disk_memory.yaxis.visible = False - self.disk_memory.rect( + self.disk_memory.hbar( source=self.source, - x="disk_memory_half", - width="disk_memory", + right="disk_memory", y="y", height=0.9, color="disk_color", @@ -3434,7 +3430,7 @@ def __init__(self, scheduler, **kwargs): ("Average Duration", "@disk_avg_duration"), ], formatters={"@disk_avg_duration": "datetime"}, - point_policy="follow_mouse", + mode="hline", ) self.disk_memory.add_tools(hover) self.disk_memory.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") @@ -3466,14 +3462,9 @@ def __init__(self, scheduler, **kwargs): self.totals.yaxis[0].formatter = NumeralTickFormatter(format="0.0 b") hover = HoverTool( - tooltips=""" -
- @x:  - @values{0.00 b} -
- """, + tooltips=[("Total", "@values{0.00b}")], + mode="vline", ) - hover.point_policy = "follow_mouse" self.totals.add_tools(hover) self.root = row(self.comm_memory, self.disk_memory) @@ -3491,7 +3482,6 @@ def update(self): "worker": [], "y": [], "comm_memory": [], - "comm_memory_half": [], "comm_memory_limit": [], "comm_buckets": [], "comm_active": [], @@ -3501,7 +3491,6 @@ def update(self): "comm_written": [], "comm_color": [], "disk_memory": [], - "disk_memory_half": [], "disk_memory_limit": [], "disk_buckets": [], "disk_active": [], @@ -3517,7 +3506,6 @@ def update(self): data["y"].append(i) data["worker"].append(worker) data["comm_memory"].append(d["comms"]["memory"]) - data["comm_memory_half"].append(d["comms"]["memory"] / 2) data["comm_memory_limit"].append(d["comms"]["memory_limit"]) data["comm_buckets"].append(d["comms"]["buckets"]) data["comm_active"].append(d["comms"]["active"]) @@ -3542,7 +3530,6 @@ def update(self): data["comm_color"].append("black") data["disk_memory"].append(d["disk"]["memory"]) - data["disk_memory_half"].append(d["disk"]["memory"] / 2) data["disk_memory_limit"].append(d["disk"]["memory_limit"]) data["disk_buckets"].append(d["disk"]["buckets"]) data["disk_active"].append(d["disk"]["active"]) From 6694c849113e049f28c8caa3dd55312e1d9bea1f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 28 Mar 2022 08:35:14 -0500 Subject: [PATCH 62/81] Move arrow utilities over to a separate file Also add pyarrow to precommit / mypy settings --- .pre-commit-config.yaml | 1 + distributed/shuffle/arrow.py | 62 +++++++++++++++++++++ distributed/shuffle/shuffle_extension.py | 69 +++--------------------- 3 files changed, 70 insertions(+), 62 deletions(-) create mode 100644 distributed/shuffle/arrow.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 329103e7d02..947b97efa5d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,3 +47,4 @@ repos: - dask - tornado - zict + - pyarrow diff --git a/distributed/shuffle/arrow.py b/distributed/shuffle/arrow.py new file mode 100644 index 00000000000..7d843c4baff --- /dev/null +++ b/distributed/shuffle/arrow.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import pyarrow as pa + + +def dump_batch(batch, file, schema=None) -> None: + """ + Dump a batch to file, if we're the first, also write the schema + + See Also + -------- + load_arrow + """ + if file.tell() == 0: + file.write(schema.serialize()) + file.write(batch) + + +def load_arrow(file) -> pa.Table: + """Load batched data written to file back out into a table again + + Example + ------- + >>> t = pa.Table.from_pandas(df) # doctest: +SKIP + >>> with open("myfile", mode="wb") as f: # doctest: +SKIP + ... for batch in t.to_batches(): # doctest: +SKIP + ... dump_batch(batch, f, schema=t.schema) # doctest: +SKIP + + >>> with open("myfile", mode="rb") as f: # doctest: +SKIP + ... t = load_arrow(f) # doctest: +SKIP + + See Also + -------- + dump_batch + """ + import pyarrow as pa + + try: + sr = pa.RecordBatchStreamReader(file) + return sr.read_all() + except Exception: + raise EOFError + + +def list_of_buffers_to_table(data: list[pa.Buffer], schema: pa.Schema) -> pa.Table: + """Convert a list of arrow buffers and a schema to an Arrow Table""" + import io + + import pyarrow as pa + + bio = io.BytesIO() + bio.write(schema.serialize()) + for batch in data: + bio.write(batch) + bio.seek(0) + sr = pa.RecordBatchStreamReader(bio) + data = sr.read_all() + bio.close() + return data diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index e8cbf6fc1c4..9983f149ce1 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -14,6 +14,7 @@ import toolz from distributed.protocol import to_serialize +from distributed.shuffle.arrow import dump_batch, list_of_buffers_to_table, load_arrow from distributed.shuffle.multi_comm import MultiComm from distributed.shuffle.multi_file import MultiFile from distributed.utils import sync @@ -457,6 +458,12 @@ def heartbeat(self, ws, data): self.shuffles[shuffle_id][ws.address].update(d) +def worker_for(output_partition: int, workers: list[str], npartitions: int) -> str: + "Get the address of the worker which should hold this output partition number" + i = len(workers) * output_partition // npartitions + return workers[i] + + def split_by_worker( df: pd.DataFrame, column: str, npartitions: int, workers: list[str] ) -> dict: @@ -516,65 +523,3 @@ def split_by_partition( breakpoint() assert len(partitions) == len(shards) return dict(zip(partitions, shards)) - - -def dump_batch(batch, file, schema=None) -> None: - """ - Dump a batch to file, if we're the first, also write the schema - - See Also - -------- - load_arrow - """ - if file.tell() == 0: - file.write(schema.serialize()) - file.write(batch) - - -def load_arrow(file) -> pa.Table: - """Load batched data written to file back out into a table again - - Example - ------- - >>> t = pa.Table.from_pandas(df) # doctest: +SKIP - >>> with open("myfile", mode="wb") as f: # doctest: +SKIP - ... for batch in t.to_batches(): # doctest: +SKIP - ... dump_batch(batch, f, schema=t.schema) # doctest: +SKIP - - >>> with open("myfile", mode="rb") as f: # doctest: +SKIP - ... t = load_arrow(f) # doctest: +SKIP - - See Also - -------- - dump_batch - """ - import pyarrow as pa - - try: - sr = pa.RecordBatchStreamReader(file) - return sr.read_all() - except Exception: - raise EOFError - - -def worker_for(output_partition: int, workers: list[str], npartitions: int) -> str: - "Get the address of the worker which should hold this output partition number" - i = len(workers) * output_partition // npartitions - return workers[i] - - -def list_of_buffers_to_table(data: list[pa.Buffer], schema: pa.Schema) -> pa.Table: - """Convert a list of arrow buffers and a schema to an Arrow Table""" - import io - - import pyarrow as pa - - bio = io.BytesIO() - bio.write(schema.serialize()) - for batch in data: - bio.write(batch) - bio.seek(0) - sr = pa.RecordBatchStreamReader(bio) - data = sr.read_all() - bio.close() - return data From 2ab401abc8d42e541c745d59b4d07bc774af70b5 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 29 Mar 2022 13:58:37 -0500 Subject: [PATCH 63/81] make multi_file tests pass --- distributed/shuffle/multi_file.py | 10 ++-- distributed/shuffle/shuffle_extension.py | 1 - distributed/shuffle/tests/test_multi_file.py | 53 ++++++++++---------- 3 files changed, 29 insertions(+), 35 deletions(-) diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 5f07646f13c..f20dac4a171 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -51,8 +51,6 @@ class MultiFile: Writes an object to a file, like pickle.dump load: callable Reads an object from that file, like pickle.load - join: callable - Joins many objects together send: callable How to send a list of shards to a worker sizeof: callable @@ -69,17 +67,14 @@ def __init__( directory, dump=pickle.dump, load=pickle.load, - join=None, sizeof=sizeof, loop=None, ): - assert join self.directory = pathlib.Path(directory) if not os.path.exists(self.directory): os.mkdir(self.directory) self.dump = dump self.load = load - self.join = join self.sizeof = sizeof self.shards = defaultdict(list) @@ -98,7 +93,7 @@ def __init__( self.diagnostics = defaultdict(float) self._communicate_future = asyncio.ensure_future(self.communicate()) - self._loop = loop + self._loop = loop or self @property def queue(self): @@ -249,7 +244,8 @@ def read(self, id): # TODO: We could consider deleting the file at this point if parts: self.bytes_read += size - return self.join(parts) + assert len(parts) == 1 + return parts[0] else: raise KeyError(id) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 9983f149ce1..3b3548d1bff 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -96,7 +96,6 @@ def __init__( ), load=load_arrow, directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), - join=pa.concat_tables, # pd.concat sizeof=lambda L: sum(map(len, L)), loop=worker.io_loop, ) diff --git a/distributed/shuffle/tests/test_multi_file.py b/distributed/shuffle/tests/test_multi_file.py index 427f6a8c77e..7a9f6f17914 100644 --- a/distributed/shuffle/tests/test_multi_file.py +++ b/distributed/shuffle/tests/test_multi_file.py @@ -1,28 +1,34 @@ import os -import random -import numpy as np -import pandas as pd import pytest from distributed.shuffle.multi_file import MultiFile -pytestmark = pytest.mark.skip(reason="Internal API has shifted") + +def dump(data, f): + f.write(data) + + +def load(f): + out = f.read() + if not out: + raise EOFError() + return out @pytest.mark.asyncio async def test_basic(tmp_path): - with MultiFile(directory=tmp_path, memory_limit="16 MiB", join=pd.concat) as mf: - df = pd.DataFrame({"x": np.arange(1000), "y": np.arange(1000) * 2}) - await mf.write(df, "a") - await mf.write(df, "b") - await mf.write(df * 2, "a") + with MultiFile(directory=tmp_path, dump=dump, load=load) as mf: + await mf.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + await mf.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) - a = mf.read("a") - b = mf.read("b") + await mf.flush() - assert (df == b).all().all() - assert (pd.concat([df, df * 2]) == a).all().all() + x = mf.read("x") + y = mf.read("y") + + assert x == b"0" * 2000 + assert y == b"1" * 1000 assert not os.path.exists(tmp_path) @@ -30,23 +36,16 @@ async def test_basic(tmp_path): @pytest.mark.asyncio @pytest.mark.parametrize("count", [2, 100, 1000]) async def test_many(tmp_path, count): - with MultiFile(directory=tmp_path, memory_limit="16 MiB", join=pd.concat) as mf: - df = pd.DataFrame({"x": np.arange(10), "y": np.arange(10) * 2}) - - L = list(range(count)) + with MultiFile(directory=tmp_path, dump=dump, load=load) as mf: + d = {i: [str(i).encode() * 100] for i in range(count)} - random.shuffle(L) - for i in L: - await mf.write(df + i, i) + for i in range(10): + await mf.put(d) - random.shuffle(L) - for i in L: - await mf.write(df * i, i) + await mf.flush() - random.shuffle(L) - for i in L: + for i in d: out = mf.read(i) - - assert (pd.concat([df + i, df * i]) == out).all().all() + assert out == str(i).encode() * 100 * 10 assert not os.path.exists(tmp_path) From 90673d1ab759a8241feb6026d44ce2a76380a97d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 29 Mar 2022 14:08:38 -0500 Subject: [PATCH 64/81] Add test for MultiComm --- distributed/shuffle/multi_comm.py | 2 +- distributed/shuffle/tests/test_multi_comm.py | 22 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 distributed/shuffle/tests/test_multi_comm.py diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index a7a4cacdda1..b28ff7861b7 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -65,7 +65,7 @@ def __init__( self._futures = set() self._done = False self.diagnostics = defaultdict(float) - self._loop = loop + self._loop = loop or self self._communicate_future = asyncio.ensure_future(self.communicate()) diff --git a/distributed/shuffle/tests/test_multi_comm.py b/distributed/shuffle/tests/test_multi_comm.py new file mode 100644 index 00000000000..cff50f8ff92 --- /dev/null +++ b/distributed/shuffle/tests/test_multi_comm.py @@ -0,0 +1,22 @@ +from collections import defaultdict + +import pytest + +from distributed.shuffle.multi_comm import MultiComm + + +@pytest.mark.asyncio +async def test_basic(tmp_path): + d = defaultdict(list) + + async def send(address, shards): + d[address].extend(shards) + + mc = MultiComm(send=send) + mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + await mc.flush() + + assert b"".join(d["x"]) == b"0" * 2000 + assert b"".join(d["y"]) == b"1" * 1000 From 7d8954a2f9d05f3ed34e5634c652c624a5947387 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 31 Mar 2022 09:48:43 -0500 Subject: [PATCH 65/81] Respond to feedback --- distributed/shuffle/arrow.py | 5 ++++- distributed/shuffle/multi_comm.py | 6 +++--- distributed/shuffle/multi_file.py | 6 +++--- distributed/shuffle/shuffle_extension.py | 10 ++++++---- distributed/worker.py | 4 +++- 5 files changed, 19 insertions(+), 12 deletions(-) diff --git a/distributed/shuffle/arrow.py b/distributed/shuffle/arrow.py index 7d843c4baff..f2e757728b7 100644 --- a/distributed/shuffle/arrow.py +++ b/distributed/shuffle/arrow.py @@ -3,7 +3,10 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - import pyarrow as pa + try: + import pyarrow as pa + except ImportError: + raise ImportError("PyArrow is needed for fast shuffling") def dump_batch(batch, file, schema=None) -> None: diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index b28ff7861b7..6fd2bc337d7 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -65,9 +65,9 @@ def __init__( self._futures = set() self._done = False self.diagnostics = defaultdict(float) - self._loop = loop or self + self._loop = loop or asyncio.get_event_loop() - self._communicate_future = asyncio.ensure_future(self.communicate()) + self._communicate_future = asyncio.create_task(self.communicate()) @property def queue(self): @@ -150,7 +150,7 @@ async def communicate(self): assert set(self.sizes) == set(self.shards) assert shards - future = asyncio.ensure_future(self.process(address, shards, size)) + future = asyncio.create_task(self.process(address, shards, size)) del shards self._futures.add(future) diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index f20dac4a171..9667193f3fd 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -92,8 +92,8 @@ def __init__( self.active = set() self.diagnostics = defaultdict(float) - self._communicate_future = asyncio.ensure_future(self.communicate()) - self._loop = loop or self + self._communicate_future = asyncio.create_task(self.communicate()) + self._loop = loop or asyncio.get_event_loop() @property def queue(self): @@ -167,7 +167,7 @@ async def communicate(self): shards = self.shards.pop(id) size = self.sizes.pop(id) - future = asyncio.ensure_future(self.process(id, shards, size)) + future = asyncio.create_task(self.process(id, shards, size)) del shards self._futures.add(future) async with self.condition: diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 3b3548d1bff..2ec51875a71 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -21,7 +21,11 @@ if TYPE_CHECKING: import pandas as pd - import pyarrow as pa + + try: + import pyarrow as pa + except ImportError: + raise ImportError("PyArrow is needed for fast shuffling") from distributed.worker import Worker @@ -432,7 +436,7 @@ def _get_shuffle(self, shuffle_id: ShuffleId) -> Shuffle: f"Shuffle {shuffle_id!r} is not registered on worker {self.worker.address}" ) from None - async def close(self): + def close(self): self.executor.shutdown() @@ -518,7 +522,5 @@ def split_by_partition( ] shards.append(t.slice(offset=splits[-1], length=None)) assert len(t) == sum(map(len, shards)) - if len(partitions) != len(shards): - breakpoint() assert len(partitions) == len(shards) return dict(zip(partitions, shards)) diff --git a/distributed/worker.py b/distributed/worker.py index c607bb5afa1..1552ef4d20b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1425,7 +1425,9 @@ async def close( for extension in self.extensions.values(): if hasattr(extension, "close"): - await extension.close() + result = extension.close() + if isawaitable(result): + result = await result if nanny and self.nanny: with self.rpc(self.nanny) as r: From 179680481888145436fbc165256a18d20e06cfcf Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 29 Mar 2022 23:00:38 +0100 Subject: [PATCH 66/81] Drop runtime dependency to setuptools (#6017) --- .pre-commit-config.yaml | 1 - continuous_integration/recipes/dask/meta.yaml | 1 + continuous_integration/recipes/distributed/meta.yaml | 5 +++-- distributed/tests/test_client.py | 7 ++++--- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 13957c312f9..cc2c6ff46ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,7 +45,6 @@ repos: - types-requests - types-paramiko - types-PyYAML - - types-setuptools - types-psutil # Typed libraries - numpy diff --git a/continuous_integration/recipes/dask/meta.yaml b/continuous_integration/recipes/dask/meta.yaml index 90c69295125..01746b7b51f 100644 --- a/continuous_integration/recipes/dask/meta.yaml +++ b/continuous_integration/recipes/dask/meta.yaml @@ -19,6 +19,7 @@ build: requirements: host: - python >=3.8 + - setuptools run: - python >=3.8 - dask-core >={{ dask_version }} diff --git a/continuous_integration/recipes/distributed/meta.yaml b/continuous_integration/recipes/distributed/meta.yaml index 13ebf926096..8222c1a07fd 100644 --- a/continuous_integration/recipes/distributed/meta.yaml +++ b/continuous_integration/recipes/distributed/meta.yaml @@ -54,11 +54,12 @@ outputs: build: - {{ compiler('c') }} # [cython_enabled] host: - - python + - python >=3.8 - pip + - setuptools - cython # [cython_enabled] run: - - python + - python >=3.8 - click >=6.6 - cloudpickle >=1.5.0 - cytoolz >=0.8.2 diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 7b419000327..7479b8cff48 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1637,6 +1637,8 @@ def g(): @gen_cluster(client=True) async def test_upload_file_egg(c, s, a, b): + pytest.importorskip("setuptools") + def g(): import package_1 import package_2 @@ -1657,9 +1659,8 @@ def g(): with open(os.path.join(dirname, "setup.py"), "w") as f: f.write("from setuptools import setup, find_packages\n") f.write( - 'setup(name="my_package", packages=find_packages(), version="{}")\n'.format( - value - ) + 'setup(name="my_package", packages=find_packages(), ' + f'version="{value}")\n' ) # test a package with an underscore in the name From 8de2793b190f3be291ee18ee07792a7b35691e73 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 30 Mar 2022 02:39:40 +0100 Subject: [PATCH 67/81] More idiomatic mypy configuration (#6022) --- .pre-commit-config.yaml | 8 +++----- setup.cfg | 5 +++++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc2c6ff46ca..20e54aaa3a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,11 +34,8 @@ repos: rev: v0.942 hooks: - id: mypy - args: - - --ignore-missing-imports - # Silence errors about Python 3.9-style delayed type annotations on Python 3.8 - - --python-version - - "3.9" + # Override default --ignore-missing-imports + args: [] additional_dependencies: # Type stubs - types-docutils @@ -46,6 +43,7 @@ repos: - types-paramiko - types-PyYAML - types-psutil + - types-setuptools # Typed libraries - numpy - dask diff --git a/setup.cfg b/setup.cfg index dd99eccfc7a..34fa189fe23 100644 --- a/setup.cfg +++ b/setup.cfg @@ -59,3 +59,8 @@ timeout_method = thread # This should not be reduced; Windows CI has been observed to be occasionally # exceptionally slow. timeout = 300 + +[mypy] +# Silence errors about Python 3.9-style delayed type annotations on Python 3.8 +python_version = 3.9 +ignore_missing_imports = true From caa852f05cb5c8eb4d5429f78cb7f04c5bf11284 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 30 Mar 2022 09:18:06 +0100 Subject: [PATCH 68/81] Python 3.10 (#5952) --- .github/workflows/tests.yaml | 8 +- continuous_integration/environment-3.10.yaml | 46 ++++++++ distributed/node.py | 5 +- distributed/profile.py | 29 +++++- distributed/tests/test_client.py | 11 ++ distributed/tests/test_profile.py | 104 +++++++++++++++++++ distributed/utils_test.py | 17 +-- setup.py | 3 + 8 files changed, 201 insertions(+), 22 deletions(-) create mode 100644 continuous_integration/environment-3.10.yaml diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 255fd4ac9a6..cd02a60eee9 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -23,7 +23,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.8", "3.9"] + python-version: ["3.8", "3.9", "3.10"] # Cherry-pick test modules to split the overall runtime roughly in half partition: [ci1, not ci1] include: @@ -65,12 +65,6 @@ jobs: shell: bash -l {0} run: conda config --show - - name: Install stacktrace - shell: bash -l {0} - # stacktrace for Python 3.8 has not been released at the moment of writing - if: ${{ matrix.os == 'ubuntu-latest' && matrix.python-version < '3.8' }} - run: mamba install -c conda-forge -c defaults -c numba libunwind stacktrace - - name: Hack around https://github.com/ipython/ipython/issues/12197 # This upstream issue causes an interpreter crash when running # distributed/protocol/tests/test_serialize.py::test_profile_nested_sizeof diff --git a/continuous_integration/environment-3.10.yaml b/continuous_integration/environment-3.10.yaml new file mode 100644 index 00000000000..0ee82922262 --- /dev/null +++ b/continuous_integration/environment-3.10.yaml @@ -0,0 +1,46 @@ +name: dask-distributed +channels: + - conda-forge + - defaults +dependencies: + - python=3.10 + - packaging + - pip + - asyncssh + - bokeh + - click + - cloudpickle + - coverage<6.3 # https://github.com/nedbat/coveragepy/issues/1310 + - dask # overridden by git tip below + - filesystem-spec # overridden by git tip below + - h5py + - ipykernel + - ipywidgets + - jinja2 + - jupyter_client + - msgpack-python + - netcdf4 + - paramiko + - pre-commit + - prometheus_client + - psutil + - pytest + - pytest-cov + - pytest-faulthandler + - pytest-repeat + - pytest-rerunfailures + - pytest-timeout + - requests + - s3fs # overridden by git tip below + - scikit-learn + - scipy + - sortedcollections + - tblib + - toolz + - tornado=6 + - zict # overridden by git tip below + - zstandard + - pip: + - git+https://github.com/dask/dask + - git+https://github.com/dask/zict + - pytest-asyncio<0.14.0 # `pytest-asyncio<0.14.0` isn't available on conda-forge for Python 3.10 diff --git a/distributed/node.py b/distributed/node.py index 6db2c7711ea..6fedd1b8ace 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -131,12 +131,9 @@ def start_http_server( import ssl ssl_options = ssl.create_default_context( - cafile=tls_ca_file, purpose=ssl.Purpose.SERVER_AUTH + cafile=tls_ca_file, purpose=ssl.Purpose.CLIENT_AUTH ) ssl_options.load_cert_chain(tls_cert, keyfile=tls_key) - # We don't care about auth here, just encryption - ssl_options.check_hostname = False - ssl_options.verify_mode = ssl.CERT_NONE self.http_server = HTTPServer(self.http_application, ssl_options=ssl_options) diff --git a/distributed/profile.py b/distributed/profile.py index bb832735e8d..22a2fc80cff 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -27,6 +27,7 @@ from __future__ import annotations import bisect +import dis import linecache import sys import threading @@ -59,21 +60,41 @@ def identifier(frame): ) +# work around some frames lacking an f_lineo eg: https://bugs.python.org/issue47085 +def _f_lineno(frame): + f_lineno = frame.f_lineno + if f_lineno is not None: + return f_lineno + + f_lasti = frame.f_lasti + code = frame.f_code + prev_line = code.co_firstlineno + + for start, next_line in dis.findlinestarts(code): + if f_lasti < start: + return prev_line + prev_line = next_line + + return prev_line + + def repr_frame(frame): """Render a frame as a line for inclusion into a text traceback""" co = frame.f_code - text = f' File "{co.co_filename}", line {frame.f_lineno}, in {co.co_name}' - line = linecache.getline(co.co_filename, frame.f_lineno, frame.f_globals).lstrip() + f_lineno = _f_lineno(frame) + text = f' File "{co.co_filename}", line {f_lineno}, in {co.co_name}' + line = linecache.getline(co.co_filename, f_lineno, frame.f_globals).lstrip() return text + "\n\t" + line def info_frame(frame): co = frame.f_code - line = linecache.getline(co.co_filename, frame.f_lineno, frame.f_globals).lstrip() + f_lineno = _f_lineno(frame) + line = linecache.getline(co.co_filename, f_lineno, frame.f_globals).lstrip() return { "filename": co.co_filename, "name": co.co_name, - "line_number": frame.f_lineno, + "line_number": f_lineno, "line": line, } diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 7479b8cff48..84c5882bdef 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -6463,6 +6463,10 @@ async def f(stacklevel, mode=None): assert "cdn.bokeh.org" in data +@pytest.mark.skipif( + sys.version_info >= (3, 10), + reason="On Py3.10+ semaphore._loop is not bound until .acquire() blocks", +) @gen_cluster(nthreads=[]) async def test_client_gather_semaphore_loop(s): async with Client(s.address, asynchronous=True) as c: @@ -6473,9 +6477,16 @@ async def test_client_gather_semaphore_loop(s): async def test_as_completed_condition_loop(c, s, a, b): seq = c.map(inc, range(5)) ac = as_completed(seq) + # consume the ac so that the ac.condition is bound to the loop on py3.10+ + async for _ in ac: + pass assert ac.condition._loop == c.loop.asyncio_loop +@pytest.mark.skipif( + sys.version_info >= (3, 10), + reason="On Py3.10+ semaphore._loop is not bound until .acquire() blocks", +) def test_client_connectionpool_semaphore_loop(s, a, b): with Client(s["address"]) as c: assert c.rpc.semaphore._loop is c.loop.asyncio_loop diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py index 7d044945d7c..75eb704c99f 100644 --- a/distributed/tests/test_profile.py +++ b/distributed/tests/test_profile.py @@ -1,5 +1,9 @@ +from __future__ import annotations + +import dataclasses import sys import threading +from collections.abc import Iterator, Sequence from time import sleep import pytest @@ -11,6 +15,7 @@ call_stack, create, identifier, + info_frame, ll_get_stack, llprocess, merge, @@ -200,3 +205,102 @@ def stop(): while threading.active_count() > start_threads: assert time() < start + 2 sleep(0.01) + + +@dataclasses.dataclass(frozen=True) +class FakeCode: + co_filename: str + co_name: str + co_firstlineno: int + co_lnotab: bytes + co_lines_seq: Sequence[tuple[int, int, int | None]] + co_code: bytes + + def co_lines(self) -> Iterator[tuple[int, int, int | None]]: + yield from self.co_lines_seq + + +FAKE_CODE = FakeCode( + co_filename="", + co_name="example", + co_firstlineno=1, + # https://github.com/python/cpython/blob/b68431fadb3150134ac6ccbf501cdfeaf4c75678/Objects/lnotab_notes.txt#L84 + # generated from: + # def example(): + # for i in range(1): + # if i >= 0: + # pass + # example.__code__.co_lnotab + co_lnotab=b"\x00\x01\x0c\x01\x08\x01\x04\xfe", + # generated with list(example.__code__.co_lines()) + co_lines_seq=[ + (0, 12, 2), + (12, 20, 3), + (20, 22, 4), + (22, 24, None), + (24, 28, 2), + ], + # used in dis.findlinestarts as bytecode_len = len(code.co_code) + # https://github.com/python/cpython/blob/6f345d363308e3e6ecf0ad518ea0fcc30afde2a8/Lib/dis.py#L457 + co_code=bytes(28), +) + + +@dataclasses.dataclass(frozen=True) +class FakeFrame: + f_lasti: int + f_code: FakeCode + f_lineno: int | None = None + f_back: FakeFrame | None = None + f_globals: dict[str, object] = dataclasses.field(default_factory=dict) + + +@pytest.mark.parametrize( + "f_lasti,f_lineno", + [ + (-1, 1), + (0, 2), + (1, 2), + (11, 2), + (12, 3), + (21, 4), + (22, 4), + (23, 4), + (24, 2), + (25, 2), + (26, 2), + (27, 2), + (100, 2), + ], +) +def test_info_frame_f_lineno(f_lasti: int, f_lineno: int) -> None: + assert info_frame(FakeFrame(f_lasti=f_lasti, f_code=FAKE_CODE)) == { + "filename": "", + "name": "example", + "line_number": f_lineno, + "line": "", + } + + +@pytest.mark.parametrize( + "f_lasti,f_lineno", + [ + (-1, 1), + (0, 2), + (1, 2), + (11, 2), + (12, 3), + (21, 4), + (22, 4), + (23, 4), + (24, 2), + (25, 2), + (26, 2), + (27, 2), + (100, 2), + ], +) +def test_call_stack_f_lineno(f_lasti: int, f_lineno: int) -> None: + assert call_stack(FakeFrame(f_lasti=f_lasti, f_code=FAKE_CODE)) == [ + f' File "", line {f_lineno}, in example\n\t' + ] diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 3e558a0b55a..a8e2652d120 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -739,13 +739,16 @@ def cluster( except KeyError: rpc_kwargs = {} - with rpc(saddr, **rpc_kwargs) as s: - while True: - nthreads = loop.run_sync(s.ncores) - if len(nthreads) == nworkers: - break - if time() - start > 5: - raise Exception("Timeout on cluster creation") + async def wait_for_workers(): + async with rpc(saddr, **rpc_kwargs) as s: + while True: + nthreads = await s.ncores() + if len(nthreads) == nworkers: + break + if time() - start > 5: + raise Exception("Timeout on cluster creation") + + loop.run_sync(wait_for_workers) # avoid sending processes down to function yield {"address": saddr}, [ diff --git a/setup.py b/setup.py index 0f57c525795..1661783f36d 100755 --- a/setup.py +++ b/setup.py @@ -98,8 +98,11 @@ "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "Topic :: Scientific/Engineering", "Topic :: System :: Distributed Computing", ], From f3fb6821029f6719f064855c4035ecf20ae468a6 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Wed, 30 Mar 2022 15:42:42 +0200 Subject: [PATCH 69/81] Cluster Dump SchedulerPlugin (#5983) Add SchedulerPlugin to dump state on cluster close This also adds a new method to SchedulerPlugins that runs directly before closing time --- distributed/client.py | 4 +- distributed/cluster_dump.py | 5 ++- distributed/diagnostics/cluster_dump.py | 38 +++++++++++++++++++ distributed/diagnostics/plugin.py | 3 ++ .../tests/test_cluster_dump_plugin.py | 21 ++++++++++ distributed/scheduler.py | 5 +++ 6 files changed, 73 insertions(+), 3 deletions(-) create mode 100644 distributed/diagnostics/cluster_dump.py create mode 100644 distributed/diagnostics/tests/test_cluster_dump_plugin.py diff --git a/distributed/client.py b/distributed/client.py index f68570f7762..d406ca6333a 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3934,8 +3934,8 @@ async def _dump_cluster_state( self, filename: str = "dask-cluster-dump", write_from_scheduler: bool | None = None, - exclude: Collection[str] = ("run_spec",), - format: Literal["msgpack", "yaml"] = "msgpack", + exclude: Collection[str] = cluster_dump.DEFAULT_CLUSTER_DUMP_EXCLUDE, + format: Literal["msgpack", "yaml"] = cluster_dump.DEFAULT_CLUSTER_DUMP_FORMAT, **storage_options, ): filename = str(filename) diff --git a/distributed/cluster_dump.py b/distributed/cluster_dump.py index 2d3a400b256..161f9091e7b 100644 --- a/distributed/cluster_dump.py +++ b/distributed/cluster_dump.py @@ -14,6 +14,9 @@ from distributed.stories import scheduler_story as _scheduler_story from distributed.stories import worker_story as _worker_story +DEFAULT_CLUSTER_DUMP_FORMAT: Literal["msgpack" | "yaml"] = "msgpack" +DEFAULT_CLUSTER_DUMP_EXCLUDE: Collection[str] = ("run_spec",) + def _tuple_to_list(node): if isinstance(node, (list, tuple)): @@ -27,7 +30,7 @@ def _tuple_to_list(node): async def write_state( get_state: Callable[[], Awaitable[Any]], url: str, - format: Literal["msgpack", "yaml"], + format: Literal["msgpack", "yaml"] = DEFAULT_CLUSTER_DUMP_FORMAT, **storage_options: dict[str, Any], ) -> None: "Await a cluster dump, then serialize and write it to a path" diff --git a/distributed/diagnostics/cluster_dump.py b/distributed/diagnostics/cluster_dump.py new file mode 100644 index 00000000000..c03b0293c05 --- /dev/null +++ b/distributed/diagnostics/cluster_dump.py @@ -0,0 +1,38 @@ +from typing import Any, Collection, Dict, Literal + +from distributed.cluster_dump import ( + DEFAULT_CLUSTER_DUMP_EXCLUDE, + DEFAULT_CLUSTER_DUMP_FORMAT, +) +from distributed.diagnostics.plugin import SchedulerPlugin +from distributed.scheduler import Scheduler + + +class ClusterDump(SchedulerPlugin): + """Dumps cluster state prior to Scheduler shutdown + + The Scheduler may shutdown in cases where it is in an error state, + or when it has been unexpectedly idle for long periods of time. + This plugin dumps the cluster state prior to Scheduler shutdown + for debugging purposes. + """ + + def __init__( + self, + url: str, + exclude: "Collection[str]" = DEFAULT_CLUSTER_DUMP_EXCLUDE, + format_: Literal["msgpack", "yaml"] = DEFAULT_CLUSTER_DUMP_FORMAT, + **storage_options: Dict[str, Any], + ): + self.url = url + self.exclude = exclude + self.format = format_ + self.storage_options = storage_options + + async def start(self, scheduler: Scheduler) -> None: + self.scheduler = scheduler + + async def before_close(self) -> None: + await self.scheduler.dump_cluster_state_to_url( + self.url, self.exclude, self.format, **self.storage_options + ) diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 9940102b6ed..cf573eed562 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -61,6 +61,9 @@ async def start(self, scheduler: Scheduler) -> None: This runs at the end of the Scheduler startup process """ + async def before_close(self) -> None: + """Runs prior to any Scheduler shutdown logic""" + async def close(self) -> None: """Run when the scheduler closes down diff --git a/distributed/diagnostics/tests/test_cluster_dump_plugin.py b/distributed/diagnostics/tests/test_cluster_dump_plugin.py new file mode 100644 index 00000000000..67ce815954d --- /dev/null +++ b/distributed/diagnostics/tests/test_cluster_dump_plugin.py @@ -0,0 +1,21 @@ +from distributed.cluster_dump import DumpArtefact +from distributed.diagnostics.cluster_dump import ClusterDump +from distributed.utils_test import gen_cluster, inc + + +@gen_cluster(client=True) +async def test_cluster_dump_plugin(c, s, *workers, tmp_path): + dump_file = tmp_path / "cluster_dump.msgpack.gz" + await c.register_scheduler_plugin(ClusterDump(str(dump_file)), name="cluster-dump") + plugin = s.plugins["cluster-dump"] + assert plugin.scheduler is s + + f1 = c.submit(inc, 1) + f2 = c.submit(inc, f1) + + assert (await f2) == 3 + await s.close(close_workers=True) + + dump = DumpArtefact.from_url(str(dump_file)) + assert {f1.key, f2.key} == set(dump.scheduler_story(f1.key, f2.key).keys()) + assert {f1.key, f2.key} == set(dump.worker_story(f1.key, f2.key).keys()) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b8d24828c89..13c7b5cd5c7 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4247,6 +4247,11 @@ async def close(self, fast=False, close_workers=False): if self.status in (Status.closing, Status.closed): await self.finished() return + + await asyncio.gather( + *[plugin.before_close() for plugin in list(self.plugins.values())] + ) + self.status = Status.closing logger.info("Scheduler closing...") From 0a1761d91613bb97834a31fc8af7fc85bd441603 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 30 Mar 2022 13:56:00 -0500 Subject: [PATCH 70/81] Add tiny test for ToPickle (#6021) --- distributed/protocol/tests/test_to_pickle.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/distributed/protocol/tests/test_to_pickle.py b/distributed/protocol/tests/test_to_pickle.py index 7db7a5d9738..d3099c9a73a 100644 --- a/distributed/protocol/tests/test_to_pickle.py +++ b/distributed/protocol/tests/test_to_pickle.py @@ -4,10 +4,22 @@ from dask.highlevelgraph import HighLevelGraph, MaterializedLayer from distributed.client import Client +from distributed.protocol import dumps, loads from distributed.protocol.serialize import ToPickle from distributed.utils_test import gen_cluster +def test_ToPickle(): + class Foo: + def __init__(self, data): + self.data = data + + msg = {"x": ToPickle(Foo(123))} + frames = dumps(msg) + out = loads(frames) + assert out["x"].data == 123 + + class NonMsgPackSerializableLayer(MaterializedLayer): """Layer that uses non-msgpack-serializable data""" From 7cdb56f29375ee9f974c8f5c0557aeacc52c16c3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 30 Mar 2022 14:28:22 -0500 Subject: [PATCH 71/81] Update gpuCI `RAPIDS_VER` to `22.06` (#5962) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- continuous_integration/gpuci/axis.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/continuous_integration/gpuci/axis.yaml b/continuous_integration/gpuci/axis.yaml index 922eee23c14..41ddb56ec2d 100644 --- a/continuous_integration/gpuci/axis.yaml +++ b/continuous_integration/gpuci/axis.yaml @@ -8,6 +8,6 @@ LINUX_VER: - ubuntu18.04 RAPIDS_VER: -- "22.04" +- "22.06" excludes: From d0afbb10474dd9cdfa8e3e3d56f2d1843ca4f94a Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 30 Mar 2022 17:50:11 -0500 Subject: [PATCH 72/81] Retry on transient error codes in preload (#5982) --- distributed/preloading.py | 16 ++++++++++++---- requirements.txt | 1 + 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/distributed/preloading.py b/distributed/preloading.py index 458642a01d9..07adbde48ba 100644 --- a/distributed/preloading.py +++ b/distributed/preloading.py @@ -6,13 +6,13 @@ import os import shutil import sys -import urllib.request from collections.abc import Iterable from importlib import import_module from types import ModuleType from typing import TYPE_CHECKING, cast import click +import urllib3 from dask.utils import tmpfile @@ -131,9 +131,17 @@ def _download_module(url: str) -> ModuleType: logger.info("Downloading preload at %s", url) assert is_webaddress(url) - request = urllib.request.Request(url, method="GET") - response = urllib.request.urlopen(request) - source = response.read().decode() + with urllib3.PoolManager() as http: + response = http.request( + method="GET", + url=url, + retries=urllib3.util.Retry( + status_forcelist=[429, 504, 503, 502], + backoff_factor=0.2, + ), + ) + + source = response.data compiled = compile(source, url, "exec") module = ModuleType(url) diff --git a/requirements.txt b/requirements.txt index b6b7ed5d824..efc734dfe6e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,5 +9,6 @@ sortedcontainers !=2.0.0, !=2.0.1 tblib >= 1.6.0 toolz >= 0.8.2 tornado >= 6.0.3 +urllib3 zict >= 0.1.3 pyyaml From a74fd3812f20abcb09647a609381ca57a244e623 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 31 Mar 2022 01:35:14 -0500 Subject: [PATCH 73/81] Remove support for PyPy (#6029) --- distributed/compatibility.py | 2 -- distributed/tests/test_spill.py | 2 -- distributed/utils_perf.py | 9 --------- docs/source/protocol.rst | 6 ++---- 4 files changed, 2 insertions(+), 17 deletions(-) diff --git a/distributed/compatibility.py b/distributed/compatibility.py index f0151267867..ca49f7b0b27 100644 --- a/distributed/compatibility.py +++ b/distributed/compatibility.py @@ -1,14 +1,12 @@ from __future__ import annotations import logging -import platform import sys logging_names: dict[str | int, int | str] = {} logging_names.update(logging._levelToName) # type: ignore logging_names.update(logging._nameToLevel) # type: ignore -PYPY = platform.python_implementation().lower() == "pypy" LINUX = sys.platform == "linux" MACOS = sys.platform == "darwin" WINDOWS = sys.platform.startswith("win") diff --git a/distributed/tests/test_spill.py b/distributed/tests/test_spill.py index 2581701aa1f..6ab595933e5 100644 --- a/distributed/tests/test_spill.py +++ b/distributed/tests/test_spill.py @@ -1,6 +1,5 @@ from __future__ import annotations -import gc import logging import os import uuid @@ -338,7 +337,6 @@ def test_weakref_cache(tmpdir, cls, expect_cached, size): # the same id as a deleted one id_x = x.id del x - gc.collect() # Only needed on pypy if size < 100: buf["y"] diff --git a/distributed/utils_perf.py b/distributed/utils_perf.py index 6643b738b99..41ff877dbc5 100644 --- a/distributed/utils_perf.py +++ b/distributed/utils_perf.py @@ -5,7 +5,6 @@ from dask.utils import format_bytes -from distributed.compatibility import PYPY from distributed.metrics import thread_time logger = _logger = logging.getLogger(__name__) @@ -144,8 +143,6 @@ def __init__(self, warn_over_frac=0.1, info_over_rss_win=10 * 1e6): self._enabled = False def enable(self): - if PYPY: - return assert not self._enabled self._fractional_timer = FractionalTimer(n_samples=self.N_SAMPLES) try: @@ -162,8 +159,6 @@ def enable(self): self._enabled = True def disable(self): - if PYPY: - return assert self._enabled gc.callbacks.remove(self._gc_callback) self._enabled = False @@ -229,8 +224,6 @@ def enable_gc_diagnosis(): """ Ask to enable global GC diagnosis. """ - if PYPY: - return global _gc_diagnosis_users with _gc_diagnosis_lock: if _gc_diagnosis_users == 0: @@ -244,8 +237,6 @@ def disable_gc_diagnosis(force=False): """ Ask to disable global GC diagnosis. """ - if PYPY: - return global _gc_diagnosis_users with _gc_diagnosis_lock: if _gc_diagnosis_users > 0: diff --git a/docs/source/protocol.rst b/docs/source/protocol.rst index 334e2c0e4bd..9f5d4990909 100644 --- a/docs/source/protocol.rst +++ b/docs/source/protocol.rst @@ -135,13 +135,11 @@ the scheduler may differ.** This has a few advantages: 1. The Scheduler is protected from unpickling unsafe code -2. The Scheduler can be run under ``pypy`` for improved performance. This is - only useful for larger clusters. -3. We could conceivably implement workers and clients for other languages +2. We could conceivably implement workers and clients for other languages (like R or Julia) and reuse the Python scheduler. The worker and client code is fairly simple and much easier to reimplement than the scheduler, which is complex. -4. The scheduler might some day be rewritten in more heavily optimized C or Go +3. The scheduler might some day be rewritten in more heavily optimized C or Go Compression ----------- From bde718fbe02054ec91833bb77be9d6c88b857d99 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 31 Mar 2022 07:54:43 -0500 Subject: [PATCH 74/81] Make test_reconnect async (#6000) This was flakey due to cleaning up resources. My experience is that making things async helps with this in general. I don't have strong confidence that this will fix the issue, but I do have mild confidence, and strong confidence that it won't hurt. --- distributed/tests/test_client.py | 81 ++++++++++++++++++-------------- distributed/tests/test_steal.py | 2 + 2 files changed, 47 insertions(+), 36 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 84c5882bdef..c50d5487a4c 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -66,7 +66,7 @@ from distributed.cluster_dump import load_cluster_dump from distributed.comm import CommClosedError from distributed.compatibility import LINUX, WINDOWS -from distributed.core import Status +from distributed.core import Server, Status from distributed.metrics import time from distributed.objects import HasWhat, WhoHas from distributed.scheduler import ( @@ -94,7 +94,6 @@ inc, map_varying, nodebug, - popen, pristine_loop, randominc, save_sys_modules, @@ -3701,60 +3700,70 @@ async def test_scatter_raises_if_no_workers(c, s): await c.scatter(1, timeout=0.5) -@pytest.mark.slow -def test_reconnect(loop): - w = Worker("127.0.0.1", 9393, loop=loop) - loop.add_callback(w.start) - - scheduler_cli = [ - "dask-scheduler", - "--host", - "127.0.0.1", - "--port", - "9393", - "--no-dashboard", - ] - with popen(scheduler_cli): - c = Client("127.0.0.1:9393", loop=loop) - c.wait_for_workers(1, timeout=10) - x = c.submit(inc, 1) - assert x.result(timeout=10) == 2 +@gen_test() +async def test_reconnect(): + async def hard_stop(s): + for pc in s.periodic_callbacks.values(): + pc.stop() + + s.stop_services() + for comm in list(s.stream_comms.values()): + comm.abort() + for comm in list(s.client_comms.values()): + comm.abort() + + await s.rpc.close() + s.stop() + await Server.close(s) + + port = 9393 + futures = [] + w = Worker(f"127.0.0.1:{port}") + futures.append(asyncio.ensure_future(w.start())) + + s = await Scheduler(port=port) + c = await Client(f"127.0.0.1:{port}", asynchronous=True) + await c.wait_for_workers(1, timeout=10) + x = c.submit(inc, 1) + assert (await x) == 2 + await hard_stop(s) start = time() while c.status != "connecting": assert time() < start + 10 - sleep(0.01) + await asyncio.sleep(0.01) assert x.status == "cancelled" with pytest.raises(CancelledError): - x.result(timeout=10) + await x - with popen(scheduler_cli): - start = time() - while c.status != "running": - sleep(0.1) - assert time() < start + 10 - start = time() - while len(c.nthreads()) != 1: - sleep(0.05) - assert time() < start + 10 + s = await Scheduler(port=port) + start = time() + while c.status != "running": + await asyncio.sleep(0.1) + assert time() < start + 10 + start = time() + while len(await c.nthreads()) != 1: + await asyncio.sleep(0.05) + assert time() < start + 10 - x = c.submit(inc, 1) - assert x.result(timeout=10) == 2 + x = c.submit(inc, 1) + assert (await x) == 2 + await hard_stop(s) start = time() while True: assert time() < start + 10 try: - x.result(timeout=10) + await x assert False except CommClosedError: continue except CancelledError: break - sync(loop, w.close, timeout=1) - c.close() + await w.close(report=False) + await c._close(fast=True) class UnhandledException(Exception): diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index e6469eae8b4..e16269d92f8 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import gc import itertools import logging import random @@ -945,6 +946,7 @@ class Foo: assert not s.who_has assert not any(s.has_what.values()) + gc.collect() assert not list(ws) From dd857b8dd23ff3043f5d5007649bce898c1b8d74 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 31 Mar 2022 19:42:05 +0100 Subject: [PATCH 75/81] Short variant of test_report.html (#6034) --- .github/workflows/test-report.yaml | 8 +- .gitignore | 2 + README.rst | 17 +- continuous_integration/scripts/test_report.py | 194 +++++++++++++----- 4 files changed, 161 insertions(+), 60 deletions(-) diff --git a/.github/workflows/test-report.yaml b/.github/workflows/test-report.yaml index e5d925cea33..26a3b27e665 100644 --- a/.github/workflows/test-report.yaml +++ b/.github/workflows/test-report.yaml @@ -2,7 +2,8 @@ name: Test Report on: schedule: - - cron: "47 6 * * *" + # Run 2h after the daily tests.yaml + - cron: "0 8,20 * * *" workflow_dispatch: jobs: @@ -38,9 +39,10 @@ jobs: - name: Generate report shell: bash -l {0} run: | - python continuous_integration/scripts/test_report.py + python continuous_integration/scripts/test_report.py --days 90 --nfails 1 -o test_report.html + python continuous_integration/scripts/test_report.py --days 7 --nfails 2 -o test_short_report.html mkdir deploy - mv test_report.html deploy/ + mv test_report.html test_short_report.html deploy/ - name: Deploy 🚀 uses: JamesIves/github-pages-deploy-action@4.1.7 diff --git a/.gitignore b/.gitignore index d7aa498cb5a..c7cf56216af 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,8 @@ tags .mypy_cache/ reports/ +test_report.* +test_short_report.html # Test failures will dump the cluster state in here test_cluster_dump/ diff --git a/README.rst b/README.rst index baf19563ad5..dc61e83f308 100644 --- a/README.rst +++ b/README.rst @@ -1,22 +1,25 @@ Distributed =========== -|Test Status| |Longitudinal Report| |Coverage| |Doc Status| |Discourse| |Version Status| |NumFOCUS| +|Test Status| |Longitudinal Report (full)| |Longitudinal Report (short)| |Coverage| |Doc Status| |Discourse| |Version Status| |NumFOCUS| A library for distributed computation. See documentation_ for more details. .. _documentation: https://distributed.dask.org .. |Test Status| image:: https://github.com/dask/distributed/workflows/Tests/badge.svg?branch=main :target: https://github.com/dask/distributed/actions?query=workflow%3A%22Tests%22 -.. |Longitudinal Report| image:: https://github.com/dask/distributed/workflows/Test%20Report/badge.svg?branch=main - :target: https://dask.github.io/distributed/test_report.html - :alt: Longitudinal test report -.. |Doc Status| image:: https://readthedocs.org/projects/distributed/badge/?version=latest - :target: https://distributed.dask.org - :alt: Documentation Status +.. |Longitudinal Report (full)| image:: https://github.com/dask/distributed/workflows/Test%20Report/badge.svg?branch=main + :target: https://dask.org/distributed/test_report.html + :alt: Longitudinal test report (full version) +.. |Longitudinal Report (short)| image:: https://github.com/dask/distributed/workflows/Test%20Report/badge.svg?branch=main + :target: https://dask.org/distributed/test_short_report.html + :alt: Longitudinal test report (short version) .. |Coverage| image:: https://codecov.io/gh/dask/distributed/branch/main/graph/badge.svg :target: https://codecov.io/gh/dask/distributed/branch/main :alt: Coverage status +.. |Doc Status| image:: https://readthedocs.org/projects/distributed/badge/?version=latest + :target: https://distributed.dask.org + :alt: Documentation Status .. |Discourse| image:: https://img.shields.io/discourse/users?logo=discourse&server=https%3A%2F%2Fdask.discourse.group :alt: Discuss Dask-related things and ask for help :target: https://dask.discourse.group diff --git a/continuous_integration/scripts/test_report.py b/continuous_integration/scripts/test_report.py index d19b214dddd..f99a1ebae57 100644 --- a/continuous_integration/scripts/test_report.py +++ b/continuous_integration/scripts/test_report.py @@ -1,10 +1,15 @@ from __future__ import annotations +import argparse import html import io import os import re +import shelve +import sys import zipfile +from collections.abc import Iterator +from typing import Any import altair import altair_saver @@ -22,7 +27,56 @@ } -def get_from_github(url, params={}): +def parse_args(argv: list[str] | None) -> argparse.Namespace: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--repo", + default="dask/distributed", + help="github repository", + ) + parser.add_argument( + "--branch", + default="main", + help="git branch", + ) + parser.add_argument( + "--events", + nargs="+", + default=["push", "schedule"], + help="github events", + ) + parser.add_argument( + "--days", + "-d", + type=int, + default=90, + help="Number of days to look back from now", + ) + parser.add_argument( + "--max-workflows", + type=int, + default=50, + help="Maximum number of workflows to fetch regardless of days", + ) + parser.add_argument( + "--nfails", + "-n", + type=int, + default=1, + help="Show test if it failed more than this many times", + ) + parser.add_argument( + "--output", + "-o", + default="test_report.html", + help="Output file name", + ) + return parser.parse_args(argv) + + +def get_from_github(url: str, params: dict[str, Any]) -> requests.Response: """ Make an authenticated request to the GitHub REST API. """ @@ -31,7 +85,7 @@ def get_from_github(url, params={}): return r -def maybe_get_next_page_path(response): +def maybe_get_next_page_path(response: requests.Response) -> str | None: """ If a response is paginated, get the url for the next page. """ @@ -48,11 +102,11 @@ def maybe_get_next_page_path(response): return next_page_path -def get_workflow_listing(repo="dask/distributed", branch="main", event="push"): +def get_workflow_listing(repo: str, branch: str, event: str, days: int): """ Get a list of workflow runs from GitHub actions. """ - since = str((pandas.Timestamp.now(tz="UTC") - pandas.Timedelta(days=90)).date()) + since = (pandas.Timestamp.now(tz="UTC") - pandas.Timedelta(days=days)).date() params = {"per_page": 100, "branch": branch, "event": event, "created": f">{since}"} r = get_from_github( f"https://api.github.com/repos/{repo}/actions/runs", params=params @@ -60,14 +114,14 @@ def get_workflow_listing(repo="dask/distributed", branch="main", event="push"): workflows = r.json()["workflow_runs"] next_page = maybe_get_next_page_path(r) while next_page: - r = get_from_github(next_page) - workflows = workflows + r.json()["workflow_runs"] + r = get_from_github(next_page, params) + workflows += r.json()["workflow_runs"] next_page = maybe_get_next_page_path(r) return workflows -def get_artifacts_for_workflow(run_id, repo="dask/distributed"): +def get_artifacts_for_workflow(run_id: str, repo: str) -> list: """ Get a list of artifacts from GitHub actions """ @@ -79,14 +133,14 @@ def get_artifacts_for_workflow(run_id, repo="dask/distributed"): artifacts = r.json()["artifacts"] next_page = maybe_get_next_page_path(r) while next_page: - r = get_from_github(next_page) - artifacts = workflows + r.json()["workflow_runs"] + r = get_from_github(next_page, params=params) + artifacts += r.json()["artifacts"] next_page = maybe_get_next_page_path(r) return artifacts -def suite_from_name(name: str): +def suite_from_name(name: str) -> str: """ Get a test suite name from an artifact name. The artifact can have matrix partitions, pytest marks, etc. Basically, @@ -95,12 +149,12 @@ def suite_from_name(name: str): return "-".join(name.split("-")[:3]) -def download_and_parse_artifact(url): +def download_and_parse_artifact(url: str): """ Download the artifact at the url parse it. """ try: - r = get_from_github(url) + r = get_from_github(url, params={}) f = zipfile.ZipFile(io.BytesIO(r.content)) run = junitparser.JUnitXml.fromstring(f.read(f.filelist[0].filename)) return run @@ -109,7 +163,7 @@ def download_and_parse_artifact(url): return None -def dataframe_from_jxml(run): +def dataframe_from_jxml(run: list) -> pandas.DataFrame: """ Turn a parsed JXML into a pandas dataframe """ @@ -161,13 +215,16 @@ def dedup(group): return df.groupby(["file", "test"]).agg(dedup) -if __name__ == "__main__": - if not TOKEN: - raise RuntimeError("Failed to find a GitHub Token") - print("Getting all recent workflows...") - workflows = get_workflow_listing(event="push") + get_workflow_listing( - event="schedule" - ) +def download_and_parse_artifacts( + repo: str, branch: str, events: list[str], days: int, max_workflows: int +) -> Iterator[pandas.DataFrame]: + + print("Getting workflows list...") + workflows = [] + for event in events: + workflows += get_workflow_listing( + repo=repo, branch=branch, event=event, days=days + ) # Filter the workflows listing to be in the retention period, # and only be test runs (i.e., no linting) that completed. @@ -176,19 +233,20 @@ def dedup(group): for w in workflows if ( pandas.to_datetime(w["created_at"]) - > pandas.Timestamp.now(tz="UTC") - pandas.Timedelta(days=90) + > pandas.Timestamp.now(tz="UTC") - pandas.Timedelta(days=days) and w["conclusion"] != "cancelled" and w["name"].lower() == "tests" ) ] + print(f"Found {len(workflows)} workflows") # Each workflow processed takes ~10-15 API requests. To avoid being # rate limited by GitHub (1000 requests per hour) we choose just the # most recent N runs. This also keeps the viz size from blowing up. - workflows = sorted(workflows, key=lambda w: w["created_at"])[-50:] + workflows = sorted(workflows, key=lambda w: w["created_at"])[-max_workflows:] + print(f"Fetching artifact listing for the {len(workflows)} most recent workflows") - print("Getting the artifact listing for each workflow...") for w in workflows: - artifacts = get_artifacts_for_workflow(w["id"]) + artifacts = get_artifacts_for_workflow(w["id"], repo=repo) # We also upload timeout reports as artifacts, but we don't want them here. w["artifacts"] = [ a @@ -196,38 +254,68 @@ def dedup(group): if "timeouts" not in a["name"] and "cluster_dumps" not in a["name"] ] - print("Downloading and parsing artifacts...") - for w in workflows: - w["dfs"] = [] - for a in w["artifacts"]: - xml = download_and_parse_artifact(a["archive_download_url"]) - df = dataframe_from_jxml(xml) if xml else None - # Note: we assign a column with the workflow timestamp rather than the - # artifact timestamp so that artifacts triggered under the same workflow - # can be aligned according to the same trigger time. - if df is not None: - df = df.assign( - name=a["name"], - suite=suite_from_name(a["name"]), - date=w["created_at"], - url=w["html_url"], - ) - w["dfs"].append(df) - - # Make a top-level dict of dataframes, mapping test name to a dataframe - # of all check suites that ran that test. - # Note: we drop **all** tests which did not have at least one failure. + nartifacts = sum(len(w["artifacts"]) for w in workflows) + ndownloaded = 0 + print(f"Downloading and parsing {nartifacts} artifacts...") + + with shelve.open("test_report") as cache: + for w in workflows: + w["dfs"] = [] + for a in w["artifacts"]: + url = a["archive_download_url"] + df: pandas.DataFrame | None + try: + df = cache[url] + except KeyError: + xml = download_and_parse_artifact(url) + if xml: + df = dataframe_from_jxml(xml) + # Note: we assign a column with the workflow timestamp rather + # than the artifact timestamp so that artifacts triggered under + # the same workflow can be aligned according to the same trigger + # time. + df = df.assign( + name=a["name"], + suite=suite_from_name(a["name"]), + date=w["created_at"], + url=w["html_url"], + ) + else: + df = None + cache[url] = df + + if df is not None: + yield df + + ndownloaded += 1 + if ndownloaded and not ndownloaded % 20: + print(f"{ndownloaded}... ", end="") + + +def main(argv: list[str] | None = None) -> None: + args = parse_args(argv) + if not TOKEN: + raise RuntimeError("Failed to find a GitHub Token") + + # Note: we drop **all** tests which did not have at least failures. # This is because, as nice as a block of green tests can be, there are # far too many tests to visualize at once, so we only want to look at # flaky tests. If the test suite has been doing well, this chart should # dwindle to nothing! - dfs = [] - for w in workflows: - dfs.extend([df for df in w["dfs"]]) + dfs = list( + download_and_parse_artifacts( + repo=args.repo, + branch=args.branch, + events=args.events, + days=args.days, + max_workflows=args.max_workflows, + ) + ) + total = pandas.concat(dfs, axis=0) grouped = ( total.groupby(total.index) - .filter(lambda g: (g.status == "x").any()) + .filter(lambda g: (g.status == "x").sum() >= args.nfails) .reset_index() .assign(test=lambda df: df.file + "." + df.test) .groupby("test") @@ -299,11 +387,17 @@ def dedup(group): .configure_title(anchor="start") .resolve_scale(x="shared") # enforce aligned x axes ) + chart.title = " ".join(argv if argv is not None else sys.argv) + altair_saver.save( chart, - "test_report.html", + args.output, embed_options={ "renderer": "svg", # Makes the text searchable "loader": {"target": "_blank"}, # Open hrefs in a new window }, ) + + +if __name__ == "__main__": + main() From 9efb27cf9338812ea3c11fb138c5c38cfab7f245 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 31 Mar 2022 14:48:44 -0500 Subject: [PATCH 76/81] Add test for bad disk --- distributed/shuffle/multi_comm.py | 2 ++ distributed/shuffle/multi_file.py | 31 +++++++++++++++-------- distributed/shuffle/shuffle_extension.py | 8 +++++- distributed/shuffle/tests/test_shuffle.py | 25 ++++++++++++++++++ 4 files changed, 54 insertions(+), 12 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index 6fd2bc337d7..60d3a4a3869 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -184,6 +184,8 @@ async def flush(self): """ We don't expect any more data, wait until everything is flushed through """ + if self._exception: + raise self._exception while self.shards: await asyncio.sleep(0.05) diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 9667193f3fd..9509bd7209a 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -94,6 +94,7 @@ def __init__( self._communicate_future = asyncio.create_task(self.communicate()) self._loop = loop or asyncio.get_event_loop() + self._exception = None @property def queue(self): @@ -116,6 +117,9 @@ async def put(self, data: dict[str, list[object]]): A dictionary mapping destinations to lists of objects that should be written to that destination """ + if self._exception: + raise self._exception + this_size = 0 for id, shard in data.items(): size = self.sizeof(shard) @@ -194,18 +198,19 @@ async def process(self, id: str, shards: list, size: int): self.active.add(id) - def _(): - with open( - self.directory / str(id), mode="ab", buffering=100_000_000 - ) as f: - for shard in shards: - self.dump(shard, f) - # os.fsync(f) # TODO: maybe? - start = time.time() - with self.time("write"): - _() - # await offload(_) + try: + with self.time("write"): + with open( + self.directory / str(id), mode="ab", buffering=100_000_000 + ) as f: + for shard in shards: + self.dump(shard, f) + # os.fsync(f) # TODO: maybe? + except Exception as e: + self._exception = e + self._done = True + stop = time.time() self.diagnostics["avg_size"] = ( @@ -225,6 +230,8 @@ def _(): def read(self, id): """Read a complete file back into memory""" + if self._exception: + raise self._exception parts = [] try: @@ -251,6 +258,8 @@ def read(self, id): async def flush(self): """Wait until all writes are finished""" + if self._exception: + raise self._exception while self.shards: await asyncio.sleep(0.05) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 2ec51875a71..08c26eea8c5 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -121,6 +121,7 @@ async def send(address, shards): self.transferred = False self.total_recvd = 0 self.start_time = time.time() + self._exception: Exception | None = None @contextlib.contextmanager def time(self, name: str): @@ -166,6 +167,8 @@ async def receive(self, data: list[pa.Buffer]) -> None: # but barriers on other workers might still be running and sending us # data # assert not self.transferred, "`receive` called after barrier task" + if self._exception: + raise self._exception import pyarrow as pa self.total_recvd += sum(map(len, data)) @@ -189,7 +192,10 @@ async def receive(self, data: list[pa.Buffer]) -> None: for k, v in groups.items() } ) - await self.multi_file.put(groups) + try: + await self.multi_file.put(groups) + except Exception as e: + self._exception = e def add_partition(self, data: pd.DataFrame) -> None: with self.time("cpu"): diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index e6d1711fa11..074c6fe895a 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -1,9 +1,11 @@ import asyncio import io +import os from collections import defaultdict import pandas as pd import pyarrow as pa +import pytest import dask import dask.dataframe as dd @@ -49,6 +51,29 @@ async def test_concurrent(c, s, a, b): assert x == y +@gen_cluster(client=True) +async def test_bad_disk(c, s, a, b): + + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + while not a.extensions["shuffle"].shuffles: + await asyncio.sleep(0.01) + os.chmod(a.local_directory, 0o444) + while not b.extensions["shuffle"].shuffles: + await asyncio.sleep(0.01) + os.chmod(b.local_directory, 0o444) + with pytest.raises(PermissionError) as e: + out = await c.compute(out) + + assert a.local_directory in str(e.value) or b.local_directory in str(e.value) + + @gen_cluster(client=True) async def test_heartbeat(c, s, a, b): await a.heartbeat() From 69bed31502c3c8e24639d7ca91ff63f4e1d1cbf9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 31 Mar 2022 15:00:22 -0500 Subject: [PATCH 77/81] Support exceptions in MultiComm --- distributed/shuffle/multi_comm.py | 11 +++++-- distributed/shuffle/tests/test_shuffle.py | 36 +++++++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index 60d3a4a3869..b24774fda94 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -68,6 +68,7 @@ def __init__( self._loop = loop or asyncio.get_event_loop() self._communicate_future = asyncio.create_task(self.communicate()) + self._exception = None @property def queue(self): @@ -89,6 +90,8 @@ def put(self, data: dict): If we're out of space then we block in order to enforce backpressure. """ + if self._exception: + raise self._exception with self.lock: for address, shards in data.items(): size = sum(map(len, shards)) @@ -164,8 +167,12 @@ async def process(self, address: str, shards: list, size: int): # while (time.time() // 5 % 4) == 0: # await asyncio.sleep(0.1) start = time.time() - with self.time("send"): - await self.send(address, [b"".join(shards)]) + try: + with self.time("send"): + await self.send(address, [b"".join(shards)]) + except Exception as e: + self._exception = e + self._done = True stop = time.time() self.diagnostics["avg_size"] = ( 0.95 * self.diagnostics["avg_size"] + 0.05 * size diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 074c6fe895a..eca57aab2b0 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -74,6 +74,42 @@ async def test_bad_disk(c, s, a, b): assert a.local_directory in str(e.value) or b.local_directory in str(e.value) +@pytest.mark.slow +@gen_cluster(client=True) +async def test_crashed_worker(c, s, a, b): + + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + while not a.extensions["shuffle"].shuffles: + await asyncio.sleep(0.01) + while not b.extensions["shuffle"].shuffles: + await asyncio.sleep(0.01) + + while ( + len( + [ + ts + for ts in s.tasks.values() + if "shuffle_transfer" in ts.key and ts.state == "memory" + ] + ) + < 3 + ): + await asyncio.sleep(0.01) + await b.close() + + with pytest.raises(Exception) as e: + out = await c.compute(out) + + assert a.address in str(e.value) or b.address in str(e.value) + + @gen_cluster(client=True) async def test_heartbeat(c, s, a, b): await a.heartbeat() From ec710914afa9794e0af695021b0dc467064c2b76 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 31 Mar 2022 15:26:38 -0500 Subject: [PATCH 78/81] add unit tests for exceptions --- distributed/shuffle/tests/test_multi_comm.py | 21 ++++++++++++++++++++ distributed/shuffle/tests/test_multi_file.py | 19 ++++++++++++++++++ distributed/shuffle/tests/test_shuffle.py | 4 ---- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/distributed/shuffle/tests/test_multi_comm.py b/distributed/shuffle/tests/test_multi_comm.py index cff50f8ff92..d048b1faad8 100644 --- a/distributed/shuffle/tests/test_multi_comm.py +++ b/distributed/shuffle/tests/test_multi_comm.py @@ -1,3 +1,4 @@ +import asyncio from collections import defaultdict import pytest @@ -20,3 +21,23 @@ async def send(address, shards): assert b"".join(d["x"]) == b"0" * 2000 assert b"".join(d["y"]) == b"1" * 1000 + + +@pytest.mark.asyncio +async def test_exceptions(tmp_path): + d = defaultdict(list) + + async def send(address, shards): + raise Exception(123) + + mc = MultiComm(send=send) + mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + while not mc._exception: + await asyncio.sleep(0.1) + + with pytest.raises(Exception, match="123"): + mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + with pytest.raises(Exception, match="123"): + await mc.flush() diff --git a/distributed/shuffle/tests/test_multi_file.py b/distributed/shuffle/tests/test_multi_file.py index 7a9f6f17914..0e042c0cc2e 100644 --- a/distributed/shuffle/tests/test_multi_file.py +++ b/distributed/shuffle/tests/test_multi_file.py @@ -1,3 +1,4 @@ +import asyncio import os import pytest @@ -49,3 +50,21 @@ async def test_many(tmp_path, count): assert out == str(i).encode() * 100 * 10 assert not os.path.exists(tmp_path) + + +@pytest.mark.asyncio +async def test_exceptions(tmp_path): + def dump(data, f): + raise Exception(123) + + with MultiFile(directory=tmp_path, dump=dump, load=load) as mf: + await mf.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + while not mf._exception: + await asyncio.sleep(0.1) + + with pytest.raises(Exception, match="123"): + await mf.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + with pytest.raises(Exception, match="123"): + await mf.flush() diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index eca57aab2b0..649f5585ab9 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -86,10 +86,6 @@ async def test_crashed_worker(c, s, a, b): ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") out = out.persist() - while not a.extensions["shuffle"].shuffles: - await asyncio.sleep(0.01) - while not b.extensions["shuffle"].shuffles: - await asyncio.sleep(0.01) while ( len( From 016ed25129ab3d5c4342b4271cf2aeeb9481de01 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 31 Mar 2022 15:33:48 -0500 Subject: [PATCH 79/81] cleanup files properly --- distributed/shuffle/tests/test_shuffle.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 649f5585ab9..7d078d21ccb 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -62,16 +62,21 @@ async def test_bad_disk(c, s, a, b): ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") out = out.persist() + original_stat = os.stat(a.local_directory).st_mode while not a.extensions["shuffle"].shuffles: await asyncio.sleep(0.01) os.chmod(a.local_directory, 0o444) while not b.extensions["shuffle"].shuffles: await asyncio.sleep(0.01) os.chmod(b.local_directory, 0o444) - with pytest.raises(PermissionError) as e: - out = await c.compute(out) - - assert a.local_directory in str(e.value) or b.local_directory in str(e.value) + try: + with pytest.raises(PermissionError) as e: + out = await c.compute(out) + + assert a.local_directory in str(e.value) or b.local_directory in str(e.value) + finally: + os.chmod(a.local_directory, original_stat) + os.chmod(b.local_directory, original_stat) @pytest.mark.slow From 8ee56056428d39f2c417685b6a03a365b81e2254 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 31 Mar 2022 15:39:42 -0500 Subject: [PATCH 80/81] cleanup extra futures --- distributed/shuffle/multi_comm.py | 4 +++- distributed/shuffle/multi_file.py | 2 ++ distributed/shuffle/tests/test_shuffle.py | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py index b24774fda94..4307cb275b4 100644 --- a/distributed/shuffle/multi_comm.py +++ b/distributed/shuffle/multi_comm.py @@ -192,7 +192,10 @@ async def flush(self): We don't expect any more data, wait until everything is flushed through """ if self._exception: + await self._communicate_future + await asyncio.gather(*self._futures) raise self._exception + while self.shards: await asyncio.sleep(0.05) @@ -202,7 +205,6 @@ async def flush(self): assert not self.total_size self._done = True - await self._communicate_future @contextlib.contextmanager diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py index 9509bd7209a..cf323d174ab 100644 --- a/distributed/shuffle/multi_file.py +++ b/distributed/shuffle/multi_file.py @@ -259,6 +259,8 @@ def read(self, id): async def flush(self): """Wait until all writes are finished""" if self._exception: + await self._communicate_future + await asyncio.gather(*self._futures) raise self._exception while self.shards: await asyncio.sleep(0.05) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 7d078d21ccb..59e3ca7d47e 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -129,6 +129,7 @@ async def test_heartbeat(c, s, a, b): await a.heartbeat() [s] = s.extensions["shuffle"].shuffles.values() + await out def test_processing_chain(): From fa235ee00c893ba0d931a742a38491c98c926c29 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 31 Mar 2022 18:53:00 -0500 Subject: [PATCH 81/81] Support windows in tests (hopefully) --- distributed/shuffle/tests/test_shuffle.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 59e3ca7d47e..027d3d24bdc 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -1,12 +1,13 @@ import asyncio import io -import os +import shutil from collections import defaultdict import pandas as pd -import pyarrow as pa import pytest +pa = pytest.importorskip("pyarrow") + import dask import dask.dataframe as dd @@ -62,21 +63,17 @@ async def test_bad_disk(c, s, a, b): ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") out = out.persist() - original_stat = os.stat(a.local_directory).st_mode while not a.extensions["shuffle"].shuffles: await asyncio.sleep(0.01) - os.chmod(a.local_directory, 0o444) + shutil.rmtree(a.local_directory) + while not b.extensions["shuffle"].shuffles: await asyncio.sleep(0.01) - os.chmod(b.local_directory, 0o444) - try: - with pytest.raises(PermissionError) as e: - out = await c.compute(out) - - assert a.local_directory in str(e.value) or b.local_directory in str(e.value) - finally: - os.chmod(a.local_directory, original_stat) - os.chmod(b.local_directory, original_stat) + shutil.rmtree(b.local_directory) + with pytest.raises(FileNotFoundError) as e: + out = await c.compute(out) + + assert a.local_directory in str(e.value) or b.local_directory in str(e.value) @pytest.mark.slow