diff --git a/distributed/_concurrent_futures_thread.py b/distributed/_concurrent_futures_thread.py new file mode 100644 index 00000000000..827c36e65d8 --- /dev/null +++ b/distributed/_concurrent_futures_thread.py @@ -0,0 +1,156 @@ +# This was copied from CPython 3.6 + +# Copyright 2009 Brian Quinlan. All Rights Reserved. +# Licensed to PSF under a Contributor Agreement. + +"""Implements ThreadPoolExecutor.""" + +__author__ = 'Brian Quinlan (brian@sweetapp.com)' + +import atexit +from concurrent.futures import _base +import itertools +try: + import queue +except ImportError: + import Queue as queue +import threading +import weakref +import os + +# Workers are created as daemon threads. This is done to allow the interpreter +# to exit when there are still idle threads in a ThreadPoolExecutor's thread +# pool (i.e. shutdown() was not called). However, allowing workers to die with +# the interpreter has two undesirable properties: +# - The workers would still be running during interpreter shutdown, +# meaning that they would fail in unpredictable ways. +# - The workers could be killed while evaluating a work item, which could +# be bad if the callable being evaluated has external side-effects e.g. +# writing to a file. +# +# To work around this problem, an exit handler is installed which tells the +# workers to exit when their work queues are empty and then waits until the +# threads finish. + +_threads_queues = weakref.WeakKeyDictionary() +_shutdown = False + +def _python_exit(): + global _shutdown + _shutdown = True + items = list(_threads_queues.items()) + for t, q in items: + q.put(None) + for t, q in items: + t.join() + +atexit.register(_python_exit) + +class _WorkItem(object): + def __init__(self, future, fn, args, kwargs): + self.future = future + self.fn = fn + self.args = args + self.kwargs = kwargs + + def run(self): + if not self.future.set_running_or_notify_cancel(): + return + + try: + result = self.fn(*self.args, **self.kwargs) + except BaseException as e: + self.future.set_exception(e) + else: + self.future.set_result(result) + +def _worker(executor_reference, work_queue): + try: + while True: + work_item = work_queue.get(block=True) + if work_item is not None: + work_item.run() + # Delete references to object. See issue16284 + del work_item + continue + executor = executor_reference() + # Exit if: + # - The interpreter is shutting down OR + # - The executor that owns the worker has been collected OR + # - The executor that owns the worker has been shutdown. + if _shutdown or executor is None or executor._shutdown: + # Notice other workers + work_queue.put(None) + return + del executor + except BaseException: + _base.LOGGER.critical('Exception in worker', exc_info=True) + +class ThreadPoolExecutor(_base.Executor): + + # Used to assign unique thread names when thread_name_prefix is not supplied. + _counter = itertools.count() + + def __init__(self, max_workers=None, thread_name_prefix=''): + """Initializes a new ThreadPoolExecutor instance. + + Args: + max_workers: The maximum number of threads that can be used to + execute the given calls. + thread_name_prefix: An optional name prefix to give our threads. + """ + if max_workers is None: + # Use this number because ThreadPoolExecutor is often + # used to overlap I/O instead of CPU work. + max_workers = (os.cpu_count() or 1) * 5 + if max_workers <= 0: + raise ValueError("max_workers must be greater than 0") + + self._max_workers = max_workers + self._work_queue = queue.Queue() + self._threads = set() + self._shutdown = False + self._shutdown_lock = threading.Lock() + self._thread_name_prefix = (thread_name_prefix or + ("ThreadPoolExecutor-%d" % next(self._counter))) + + def submit(self, fn, *args, **kwargs): + with self._shutdown_lock: + if self._shutdown: + raise RuntimeError('cannot schedule new futures after shutdown') + + f = _base.Future() + w = _WorkItem(f, fn, args, kwargs) + + self._work_queue.put(w) + self._adjust_thread_count() + return f + submit.__doc__ = _base.Executor.submit.__doc__ + + def _adjust_thread_count(self): + # When the executor gets lost, the weakref callback will wake up + # the worker threads. + def weakref_cb(_, q=self._work_queue): + q.put(None) + # TODO(bquinlan): Should avoid creating new threads if there are more + # idle threads than items in the work queue. + num_threads = len(self._threads) + if num_threads < self._max_workers: + thread_name = '%s_%d' % (self._thread_name_prefix or self, + num_threads) + t = threading.Thread(name=thread_name, target=_worker, + args=(weakref.ref(self, weakref_cb), + self._work_queue)) + t.daemon = True + t.start() + self._threads.add(t) + _threads_queues[t] = self._work_queue + + def shutdown(self, wait=True): + with self._shutdown_lock: + self._shutdown = True + self._work_queue.put(None) + if wait: + for t in self._threads: + t.join() + shutdown.__doc__ = _base.Executor.shutdown.__doc__ diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 1d9788e894f..9bfee8c7d3e 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3197,7 +3197,7 @@ def test_Client_clears_references_after_restart(c, s, a, b): def test_get_stops_work_after_error(loop): - with cluster() as (s, [a, b]): + with cluster(active_rpc_timeout=10) as (s, [a, b]): with Client(s['address'], loop=loop) as c: with pytest.raises(RuntimeError): c.get({'x': (throws, 1), 'y': (sleep, 1.5)}, ['x', 'y']) diff --git a/distributed/tests/test_client_executor.py b/distributed/tests/test_client_executor.py index 88059e7dd76..e54644bddae 100644 --- a/distributed/tests/test_client_executor.py +++ b/distributed/tests/test_client_executor.py @@ -210,7 +210,7 @@ def test_unsupported_arguments(loop): def test_shutdown(loop): - with cluster() as (s, [a, b]): + with cluster(active_rpc_timeout=10) as (s, [a, b]): with Client(s['address'], loop=loop) as c: # shutdown(wait=True) waits for pending tasks to finish e = c.get_executor() diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 347376c332f..2039c75afc6 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -311,7 +311,6 @@ def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest): assert not any(w.task_state for w in rest) -@pytest.mark.skip(reason='leaks large amount of memory') @gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10) def test_steal_when_more_tasks(c, s, a, *rest): s.extensions['stealing']._pc.callback_time = 20 @@ -326,7 +325,6 @@ def test_steal_when_more_tasks(c, s, a, *rest): assert any(w.task_state for w in rest) -@pytest.mark.skip(reason='leaks large amount of memory') @gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10) def test_steal_more_attractive_tasks(c, s, a, *rest): def slow2(x): @@ -474,8 +472,6 @@ def test_restart(c, s, a, b): assert not any(x for L in steal.stealable.values() for x in L) -# @pytest.mark.avoid_travis # leaks large amounts of memory -@pytest.mark.skip(reason='leaks memory') @gen_cluster(client=True) def test_steal_communication_heavy_tasks(c, s, a, b): s.task_duration['slowadd'] = 0.001 diff --git a/distributed/tests/test_threadpoolexecutor.py b/distributed/tests/test_threadpoolexecutor.py index 3b46724d34f..bd7024278ff 100644 --- a/distributed/tests/test_threadpoolexecutor.py +++ b/distributed/tests/test_threadpoolexecutor.py @@ -24,3 +24,25 @@ def f(): while all(t.is_alive() for t in threads): sleep(0.01) assert time() < start + 1 + + +def test_shutdown_timeout(): + e = ThreadPoolExecutor(1) + futures = [e.submit(sleep, 0.1 * i) for i in range(1, 3, 1)] + sleep(0.01) + + start = time() + e.shutdown() + end = time() + assert end - start > 0.1 + + +def test_shutdown_timeout_raises(): + e = ThreadPoolExecutor(1) + futures = [e.submit(sleep, 0.1 * i) for i in range(1, 3, 1)] + sleep(0.05) + + start = time() + e.shutdown(timeout=0.1) + end = time() + assert end - start > 0.05 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index d075d89b2ab..21c92aca6e2 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -402,8 +402,9 @@ def test_spill_by_default(c, s, w): da = pytest.importorskip('dask.array') x = da.ones(int(TOTAL_MEMORY * 0.7), chunks=10000000, dtype='u1') y = c.persist(x) - yield _wait(y) + yield wait(y) assert len(w.data.slow) # something is on disk + del x, y @gen_cluster(ncores=[('127.0.0.1', 1)], diff --git a/distributed/threadpoolexecutor.py b/distributed/threadpoolexecutor.py index d59ce49723b..06a1dd65da6 100644 --- a/distributed/threadpoolexecutor.py +++ b/distributed/threadpoolexecutor.py @@ -22,11 +22,12 @@ """ from __future__ import print_function, division, absolute_import -from concurrent.futures import thread +from . import _concurrent_futures_thread as thread import logging import threading from .compatibility import get_thread_identity +from .metrics import time logger = logging.getLogger(__name__) @@ -64,9 +65,19 @@ def _adjust_thread_count(self): self._threads.add(t) t.start() - def shutdown(self, wait=True): + def shutdown(self, wait=True, timeout=None): with threads_lock: - thread.ThreadPoolExecutor.shutdown(self, wait=wait) + with self._shutdown_lock: + self._shutdown = True + self._work_queue.put(None) + if timeout is not None: + deadline = time() + timeout + for t in self._threads: + if timeout is not None: + timeout2 = max(deadline - time(), 0) + else: + timeout2 = None + t.join(timeout=timeout2) def secede(): diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 32c2248d54a..40f2b476c8f 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -310,7 +310,7 @@ def run_nanny(q, scheduler_q, **kwargs): @contextmanager -def check_active_rpc(loop, active_rpc_timeout=0): +def check_active_rpc(loop, active_rpc_timeout=1): if rpc.active > 0: # Streams from a previous test dangling around? gc.collect() @@ -339,7 +339,7 @@ def wait_a_bit(): @contextmanager -def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=0, +def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1, scheduler_kwargs={}, should_check_state=True): ws = weakref.WeakSet() before = process_state() @@ -522,7 +522,7 @@ def check_state(before, after): break start = time() - while after['used-memory'] > before['used-memory'] + 1e8: + while after['used-memory'] > before['used-memory'] + 1e7: gc.collect() sleep(0.10) after = process_state() @@ -535,7 +535,7 @@ def check_state(before, after): "total leaked total", (after['used-memory'] - initial_state['used-memory']) / 1e6) # , end=' ') total_diff = after['used-memory'] - initial_state['used-memory'] - assert total_diff < 3e9, total_diff + assert total_diff < 2e9, total_diff @gen.coroutine @@ -592,7 +592,7 @@ def iscoroutinefunction(f): def gen_cluster(ncores=[('127.0.0.1', 1), ('127.0.0.1', 2)], scheduler='127.0.0.1', timeout=10, security=None, Worker=Worker, client=False, scheduler_kwargs={}, - worker_kwargs={}, active_rpc_timeout=0, should_check_state=True): + worker_kwargs={}, active_rpc_timeout=1, should_check_state=True): from distributed import Client """ Coroutine test with small cluster diff --git a/distributed/worker.py b/distributed/worker.py index feacd928710..f9667047e8d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -330,7 +330,10 @@ def _close(self, report=True, timeout=10, nanny=True): self.scheduler.unregister(address=self.address), io_loop=self.loop) self.scheduler.close_rpc() - self.executor.shutdown(wait=False) + if isinstance(self.executor, ThreadPoolExecutor): + self.executor.shutdown(timeout=timeout) + else: + self.executor.shutdown(wait=False) if os.path.exists(self.local_dir): shutil.rmtree(self.local_dir) diff --git a/setup.cfg b/setup.cfg index 6b8f9e9250a..29e2f0c6814 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,7 +4,7 @@ # https://flake8.readthedocs.io/en/latest/user/error-codes.html # Note: there cannot be spaces after comma's here -exclude = __init__.py +exclude = __init__.py,distributed/_concurrent_futures_thread.py ignore = # Extra space in brackets E20,