Skip to content
12 changes: 12 additions & 0 deletions distributed/distributed-schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,18 @@ properties:
Should be used for variables that must be set before
process startup, interpreter startup, or imports.

shuffle:
type: object
description: |
Low-level settings for control of p2p shuffle
properties:
output_max_buffer_size:
type: [string, integer, 'null']
description: |
Maximum size of the in-memory output buffer for p2p
shuffles before a worker writes output to disk. If
``None`` then a default of one quarter of the worker's
total memory is used.
client:
type: object
description: |
Expand Down
3 changes: 3 additions & 0 deletions distributed/distributed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ distributed:
MKL_NUM_THREADS: 1
OPENBLAS_NUM_THREADS: 1

shuffle:
output_max_buffer_size: null # Size of shuffle output memory buffer

client:
heartbeat: 5s # Interval between client heartbeats
scheduler-info-interval: 2s # Interval between scheduler-info updates
Expand Down
58 changes: 45 additions & 13 deletions distributed/shuffle/_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import pathlib
import shutil
from collections import defaultdict

from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
Expand Down Expand Up @@ -36,12 +37,19 @@ class DiskShardsBuffer(ShardsBuffer):
to be processed exceeds this limit, then the buffer will block
until below the threshold. See :meth:`.write` for the
implementation of this scheme.
max_in_memory_buffer_size : int, optional
Size of in-memory buffer to use before flushing to disk. If
configured, incoming shards will first be moved to memory
Comment thread
wence- marked this conversation as resolved.
rather than immediately written to disk. This can provide for
speedups when an entire shuffle fits in memory.
"""

def __init__(
self,
directory: str | pathlib.Path,
memory_limiter: ResourceLimiter | None = None,
*,
max_in_memory_buffer_size: int = 0,
):
super().__init__(
memory_limiter=memory_limiter,
Expand All @@ -50,6 +58,9 @@ def __init__(
)
self.directory = pathlib.Path(directory)
self.directory.mkdir(exist_ok=True)
self._in_memory = 0
self._memory_buf: defaultdict[str, list[bytes]] = defaultdict(list)
self.max_in_memory_buffer_size = max_in_memory_buffer_size

async def _process(self, id: str, shards: list[bytes]) -> None:
"""Write one buffer to file
Expand All @@ -68,31 +79,52 @@ async def _process(self, id: str, shards: list[bytes]) -> None:
with log_errors():
# Consider boosting total_size a bit here to account for duplication
with self.time("write"):
with open(
self.directory / str(id), mode="ab", buffering=100_000_000
) as f:
for shard in shards:
f.write(shard)

def read(self, id: int | str) -> bytes:
"""Read a complete file back into memory"""
if not self.max_in_memory_buffer_size:
# Fast path if we're always hitting the disk
self._write(id, shards)
else:
while shards:
if self._in_memory < self.max_in_memory_buffer_size:
self._memory_buf[id].append(newdata := shards.pop())
self._in_memory += len(newdata)
else:
# Flush old data
# This could be offloaded to a background
# task at the cost of going further over
# the soft memory limit.
for k, v in self._memory_buf.items():
self._write(k, v)
self._memory_buf.clear()
self._in_memory = 0
Comment thread
wence- marked this conversation as resolved.

def _write(self, id: str, shards: list[bytes]) -> None:
with open(self.directory / str(id), mode="ab", buffering=100_000_000) as f:
for s in shards:
f.write(s)

def read(self, id: str) -> bytes:
"""Read a complete file back into memory, concatting with any
in memory parts"""
self.raise_on_exception()
if not self._inputs_done:
raise RuntimeError("Tried to read from file before done.")

data = self._memory_buf.pop(id, [])
try:
with self.time("read"):
with open(
self.directory / str(id), mode="rb", buffering=100_000_000
) as f:
data = f.read()
size = f.tell()
data.append(f.read())
except FileNotFoundError:
raise KeyError(id)
if not data:
# Neither disk nor in memory
raise KeyError(id)

if data:
self.bytes_read += size
return data
buf = b"".join(data)
self.bytes_read += len(buf)
return buf
else:
raise KeyError(id)

Expand Down
22 changes: 22 additions & 0 deletions distributed/shuffle/_worker_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import toolz

from dask import config
from dask.utils import parse_bytes

from distributed.core import PooledRPCCall
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
worker_memory_limit: int | None,
):
self.id = id
self.run_id = run_id
Expand All @@ -78,9 +80,23 @@ def __init__(
self.scheduler = scheduler
self.closed = False

buffer_size: int | str | None = config.get(
"distributed.shuffle.output_max_buffer_size"
)
if buffer_size is None:
if worker_memory_limit is None:
# No configuration and no known worker memory limit
# Safe default is "no in-memory buffering"
buffer_size = 0
else:
buffer_size = worker_memory_limit // 4
else:
buffer_size = parse_bytes(buffer_size)

self._disk_buffer = DiskShardsBuffer(
directory=directory,
memory_limiter=memory_limiter_disk,
max_in_memory_buffer_size=buffer_size,
)

self._comm_buffer = CommShardsBuffer(
Expand Down Expand Up @@ -281,6 +297,7 @@ def __init__(
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
worker_memory_limit: int | None = None,
):
from dask.array.rechunk import _old_to_new

Expand All @@ -295,6 +312,7 @@ def __init__(
scheduler=scheduler,
memory_limiter_comms=memory_limiter_comms,
memory_limiter_disk=memory_limiter_disk,
worker_memory_limit=worker_memory_limit,
)
self.old = old
self.new = new
Expand Down Expand Up @@ -436,6 +454,7 @@ def __init__(
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
worker_memory_limit: int | None = None,
):
import pandas as pd

Expand All @@ -450,6 +469,7 @@ def __init__(
scheduler=scheduler,
memory_limiter_comms=memory_limiter_comms,
memory_limiter_disk=memory_limiter_disk,
worker_memory_limit=worker_memory_limit,
)
self.column = column
self.schema = schema
Expand Down Expand Up @@ -818,6 +838,7 @@ async def _(
scheduler=self.worker.scheduler,
memory_limiter_disk=self.memory_limiter_disk,
memory_limiter_comms=self.memory_limiter_comms,
worker_memory_limit=self.worker.memory_manager.memory_limit,
)
elif result["type"] == ShuffleType.ARRAY_RECHUNK:
shuffle = ArrayRechunkRun(
Expand All @@ -837,6 +858,7 @@ async def _(
scheduler=self.worker.scheduler,
memory_limiter_disk=self.memory_limiter_disk,
memory_limiter_comms=self.memory_limiter_comms,
worker_memory_limit=self.worker.memory_manager.memory_limit,
)
else: # pragma: no cover
raise TypeError(result["type"])
Expand Down
Loading