diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 417b7bd5be8..ed21507e041 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -19,6 +19,7 @@ distributed: idle-timeout: null # Shut down after this duration, like "1h" or "30 minutes" transition-log-length: 100000 work-stealing: True # workers should steal tasks from each other + work-stealing-interval: 100ms # Callback time for work stealing worker-ttl: null # like '60s'. Time to live for workers. They must heartbeat faster than this pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings preload: [] diff --git a/distributed/stealing.py b/distributed/stealing.py index e3537f647bf..b14a2a8de6d 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -6,7 +6,7 @@ import dask from .core import CommClosedError from .diagnostics.plugin import SchedulerPlugin -from .utils import log_errors, PeriodicCallback +from .utils import log_errors, parse_timedelta, PeriodicCallback try: from cytoolz import topk @@ -40,8 +40,15 @@ def __init__(self, scheduler): for worker in scheduler.workers: self.add_worker(worker=worker) + # `callback_time` is in milliseconds + callback_time = 1000 * parse_timedelta( + dask.config.get("distributed.scheduler.work-stealing-interval"), + default="ms", + ) pc = PeriodicCallback( - callback=self.balance, callback_time=100, io_loop=self.scheduler.loop + callback=self.balance, + callback_time=callback_time, + io_loop=self.scheduler.loop, ) self._pc = pc self.scheduler.periodic_callbacks["stealing"] = pc diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index a6a19332f5f..b017bff4371 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -9,6 +9,7 @@ from toolz import sliding_window, concat from tornado import gen +import dask from distributed import Nanny, Worker, wait, worker_client from distributed.config import config from distributed.metrics import time @@ -676,3 +677,20 @@ def test_lose_task(c, s, a, b): out = log.getvalue() assert "Error" not in out + + +@gen_cluster(client=True) +def test_worker_stealing_interval(c, s, a, b): + from distributed.scheduler import WorkStealing + + ws = WorkStealing(s) + assert ws._pc.callback_time == 100 + + with dask.config.set({"distributed.scheduler.work-stealing-interval": "500ms"}): + ws = WorkStealing(s) + assert ws._pc.callback_time == 500 + + # Default unit is `ms` + with dask.config.set({"distributed.scheduler.work-stealing-interval": 2}): + ws = WorkStealing(s) + assert ws._pc.callback_time == 2