Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions distributed/_concurrent_futures_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# This was copied from CPython 3.6

# Copyright 2009 Brian Quinlan. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Implements ThreadPoolExecutor."""

__author__ = 'Brian Quinlan (brian@sweetapp.com)'

import atexit
from concurrent.futures import _base
import itertools
try:
import queue
except ImportError:
import Queue as queue
import threading
import weakref
import os

# Workers are created as daemon threads. This is done to allow the interpreter
# to exit when there are still idle threads in a ThreadPoolExecutor's thread
# pool (i.e. shutdown() was not called). However, allowing workers to die with
# the interpreter has two undesirable properties:
# - The workers would still be running during interpreter shutdown,
# meaning that they would fail in unpredictable ways.
# - The workers could be killed while evaluating a work item, which could
# be bad if the callable being evaluated has external side-effects e.g.
# writing to a file.
#
# To work around this problem, an exit handler is installed which tells the
# workers to exit when their work queues are empty and then waits until the
# threads finish.

_threads_queues = weakref.WeakKeyDictionary()
_shutdown = False

def _python_exit():
global _shutdown
_shutdown = True
items = list(_threads_queues.items())
for t, q in items:
q.put(None)
for t, q in items:
t.join()

atexit.register(_python_exit)

class _WorkItem(object):
def __init__(self, future, fn, args, kwargs):
self.future = future
self.fn = fn
self.args = args
self.kwargs = kwargs

def run(self):
if not self.future.set_running_or_notify_cancel():
return

try:
result = self.fn(*self.args, **self.kwargs)
except BaseException as e:
self.future.set_exception(e)
else:
self.future.set_result(result)

def _worker(executor_reference, work_queue):
try:
while True:
work_item = work_queue.get(block=True)
if work_item is not None:
work_item.run()
# Delete references to object. See issue16284
del work_item
continue
executor = executor_reference()
# Exit if:
# - The interpreter is shutting down OR
# - The executor that owns the worker has been collected OR
# - The executor that owns the worker has been shutdown.
if _shutdown or executor is None or executor._shutdown:
# Notice other workers
work_queue.put(None)
return
del executor
except BaseException:
_base.LOGGER.critical('Exception in worker', exc_info=True)

class ThreadPoolExecutor(_base.Executor):

# Used to assign unique thread names when thread_name_prefix is not supplied.
_counter = itertools.count()

def __init__(self, max_workers=None, thread_name_prefix=''):
"""Initializes a new ThreadPoolExecutor instance.

Args:
max_workers: The maximum number of threads that can be used to
execute the given calls.
thread_name_prefix: An optional name prefix to give our threads.
"""
if max_workers is None:
# Use this number because ThreadPoolExecutor is often
# used to overlap I/O instead of CPU work.
max_workers = (os.cpu_count() or 1) * 5
if max_workers <= 0:
raise ValueError("max_workers must be greater than 0")

self._max_workers = max_workers
self._work_queue = queue.Queue()
self._threads = set()
self._shutdown = False
self._shutdown_lock = threading.Lock()
self._thread_name_prefix = (thread_name_prefix or
("ThreadPoolExecutor-%d" % next(self._counter)))

def submit(self, fn, *args, **kwargs):
with self._shutdown_lock:
if self._shutdown:
raise RuntimeError('cannot schedule new futures after shutdown')

f = _base.Future()
w = _WorkItem(f, fn, args, kwargs)

self._work_queue.put(w)
self._adjust_thread_count()
return f
submit.__doc__ = _base.Executor.submit.__doc__

def _adjust_thread_count(self):
# When the executor gets lost, the weakref callback will wake up
# the worker threads.
def weakref_cb(_, q=self._work_queue):
q.put(None)
# TODO(bquinlan): Should avoid creating new threads if there are more
# idle threads than items in the work queue.
num_threads = len(self._threads)
if num_threads < self._max_workers:
thread_name = '%s_%d' % (self._thread_name_prefix or self,
num_threads)
t = threading.Thread(name=thread_name, target=_worker,
args=(weakref.ref(self, weakref_cb),
self._work_queue))
t.daemon = True
t.start()
self._threads.add(t)
_threads_queues[t] = self._work_queue

def shutdown(self, wait=True):
with self._shutdown_lock:
self._shutdown = True
self._work_queue.put(None)
if wait:
for t in self._threads:
t.join()
shutdown.__doc__ = _base.Executor.shutdown.__doc__
2 changes: 1 addition & 1 deletion distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3197,7 +3197,7 @@ def test_Client_clears_references_after_restart(c, s, a, b):


def test_get_stops_work_after_error(loop):
with cluster() as (s, [a, b]):
with cluster(active_rpc_timeout=10) as (s, [a, b]):
with Client(s['address'], loop=loop) as c:
with pytest.raises(RuntimeError):
c.get({'x': (throws, 1), 'y': (sleep, 1.5)}, ['x', 'y'])
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_client_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_unsupported_arguments(loop):


def test_shutdown(loop):
with cluster() as (s, [a, b]):
with cluster(active_rpc_timeout=10) as (s, [a, b]):
with Client(s['address'], loop=loop) as c:
# shutdown(wait=True) waits for pending tasks to finish
e = c.get_executor()
Expand Down
4 changes: 0 additions & 4 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ def test_dont_steal_few_saturated_tasks_many_workers(c, s, a, *rest):
assert not any(w.task_state for w in rest)


@pytest.mark.skip(reason='leaks large amount of memory')
@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10)
def test_steal_when_more_tasks(c, s, a, *rest):
s.extensions['stealing']._pc.callback_time = 20
Expand All @@ -326,7 +325,6 @@ def test_steal_when_more_tasks(c, s, a, *rest):
assert any(w.task_state for w in rest)


@pytest.mark.skip(reason='leaks large amount of memory')
@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 10)
def test_steal_more_attractive_tasks(c, s, a, *rest):
def slow2(x):
Expand Down Expand Up @@ -474,8 +472,6 @@ def test_restart(c, s, a, b):
assert not any(x for L in steal.stealable.values() for x in L)


# @pytest.mark.avoid_travis # leaks large amounts of memory
@pytest.mark.skip(reason='leaks memory')
@gen_cluster(client=True)
def test_steal_communication_heavy_tasks(c, s, a, b):
s.task_duration['slowadd'] = 0.001
Expand Down
22 changes: 22 additions & 0 deletions distributed/tests/test_threadpoolexecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,25 @@ def f():
while all(t.is_alive() for t in threads):
sleep(0.01)
assert time() < start + 1


def test_shutdown_timeout():
e = ThreadPoolExecutor(1)
futures = [e.submit(sleep, 0.1 * i) for i in range(1, 3, 1)]
sleep(0.01)

start = time()
e.shutdown()
end = time()
assert end - start > 0.1


def test_shutdown_timeout_raises():
e = ThreadPoolExecutor(1)
futures = [e.submit(sleep, 0.1 * i) for i in range(1, 3, 1)]
sleep(0.05)

start = time()
e.shutdown(timeout=0.1)
end = time()
assert end - start > 0.05
3 changes: 2 additions & 1 deletion distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,9 @@ def test_spill_by_default(c, s, w):
da = pytest.importorskip('dask.array')
x = da.ones(int(TOTAL_MEMORY * 0.7), chunks=10000000, dtype='u1')
y = c.persist(x)
yield _wait(y)
yield wait(y)
assert len(w.data.slow) # something is on disk
del x, y


@gen_cluster(ncores=[('127.0.0.1', 1)],
Expand Down
17 changes: 14 additions & 3 deletions distributed/threadpoolexecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
"""
from __future__ import print_function, division, absolute_import

from concurrent.futures import thread
from . import _concurrent_futures_thread as thread
import logging
import threading

from .compatibility import get_thread_identity
from .metrics import time

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -64,9 +65,19 @@ def _adjust_thread_count(self):
self._threads.add(t)
t.start()

def shutdown(self, wait=True):
def shutdown(self, wait=True, timeout=None):
with threads_lock:
thread.ThreadPoolExecutor.shutdown(self, wait=wait)
with self._shutdown_lock:
self._shutdown = True
self._work_queue.put(None)
if timeout is not None:
deadline = time() + timeout
for t in self._threads:
if timeout is not None:
timeout2 = max(deadline - time(), 0)
else:
timeout2 = None
t.join(timeout=timeout2)


def secede():
Expand Down
10 changes: 5 additions & 5 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def run_nanny(q, scheduler_q, **kwargs):


@contextmanager
def check_active_rpc(loop, active_rpc_timeout=0):
def check_active_rpc(loop, active_rpc_timeout=1):
if rpc.active > 0:
# Streams from a previous test dangling around?
gc.collect()
Expand Down Expand Up @@ -339,7 +339,7 @@ def wait_a_bit():


@contextmanager
def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=0,
def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1,
scheduler_kwargs={}, should_check_state=True):
ws = weakref.WeakSet()
before = process_state()
Expand Down Expand Up @@ -522,7 +522,7 @@ def check_state(before, after):
break

start = time()
while after['used-memory'] > before['used-memory'] + 1e8:
while after['used-memory'] > before['used-memory'] + 1e7:
gc.collect()
sleep(0.10)
after = process_state()
Expand All @@ -535,7 +535,7 @@ def check_state(before, after):
"total leaked total", (after['used-memory'] - initial_state['used-memory']) / 1e6) # , end=' ')

total_diff = after['used-memory'] - initial_state['used-memory']
assert total_diff < 3e9, total_diff
assert total_diff < 2e9, total_diff


@gen.coroutine
Expand Down Expand Up @@ -592,7 +592,7 @@ def iscoroutinefunction(f):
def gen_cluster(ncores=[('127.0.0.1', 1), ('127.0.0.1', 2)],
scheduler='127.0.0.1', timeout=10, security=None,
Worker=Worker, client=False, scheduler_kwargs={},
worker_kwargs={}, active_rpc_timeout=0, should_check_state=True):
worker_kwargs={}, active_rpc_timeout=1, should_check_state=True):
from distributed import Client
""" Coroutine test with small cluster

Expand Down
5 changes: 4 additions & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,10 @@ def _close(self, report=True, timeout=10, nanny=True):
self.scheduler.unregister(address=self.address),
io_loop=self.loop)
self.scheduler.close_rpc()
self.executor.shutdown(wait=False)
if isinstance(self.executor, ThreadPoolExecutor):
self.executor.shutdown(timeout=timeout)
else:
self.executor.shutdown(wait=False)
if os.path.exists(self.local_dir):
shutil.rmtree(self.local_dir)

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# https://flake8.readthedocs.io/en/latest/user/error-codes.html

# Note: there cannot be spaces after comma's here
exclude = __init__.py
exclude = __init__.py,distributed/_concurrent_futures_thread.py
ignore =
# Extra space in brackets
E20,
Expand Down