From fddb0e19b4bd65db0ff4c6eb2876b604e34afb41 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 5 Jul 2017 14:22:07 -0400 Subject: [PATCH 1/5] Make ThreadPoolExecutor threadsafe to secede operation --- distributed/threadpoolexecutor.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/distributed/threadpoolexecutor.py b/distributed/threadpoolexecutor.py index efd7aad9a37..ba9d0ce460a 100644 --- a/distributed/threadpoolexecutor.py +++ b/distributed/threadpoolexecutor.py @@ -24,13 +24,13 @@ from concurrent.futures import thread import logging -from threading import local, Thread +import threading from .compatibility import get_thread_identity logger = logging.getLogger(__name__) -thread_state = local() +thread_state = threading.local() def _worker(executor, work_queue): @@ -57,23 +57,31 @@ def _worker(executor, work_queue): class ThreadPoolExecutor(thread.ThreadPoolExecutor): def _adjust_thread_count(self): if len(self._threads) < self._max_workers: - t = Thread(target=_worker, - name="ThreadPool worker %d" % len(self._threads,), - args=(self, self._work_queue)) + t = threading.Thread(target=_worker, + name="ThreadPool worker %d" % len(self._threads,), + args=(self, self._work_queue)) t.daemon = True self._threads.add(t) t.start() + def shutdown(self): + with threads_lock: + thread.ThreadPoolExecutor.shutdown(self) + def secede(): """ Have this thread secede from the ThreadPoolExecutor """ thread_state.proceed = False ident = get_thread_identity() - for t in list(thread_state.executor._threads): - if t.ident == ident: - thread_state.executor._threads.remove(t) - break + with threads_lock: + for t in list(thread_state.executor._threads): + if t.ident == ident: + thread_state.executor._threads.remove(t) + break + thread_state.executor._adjust_thread_count() + +threads_lock = threading.Lock() """ PSF LICENSE AGREEMENT FOR PYTHON 3.5.2 From 84c586e386fa057078fad90be1528d0413bbe49b Mon Sep 17 00:00:00 2001 From: Matthew 
Rocklin Date: Wed, 5 Jul 2017 14:23:35 -0400 Subject: [PATCH 2/5] Remove channels These weren't heavily used and where they were used they are now generally replaceable with Queues, which are cleaner. --- distributed/channels.py | 297 ----------------------- distributed/client.py | 8 +- distributed/scheduler.py | 3 +- distributed/tests/py3_test_asyncio.py | 40 --- distributed/tests/test_channels.py | 231 ------------------ distributed/tests/test_tls_functional.py | 19 +- docs/source/channels.rst | 149 ------------ docs/source/index.rst | 1 - 8 files changed, 16 insertions(+), 732 deletions(-) delete mode 100644 distributed/channels.py delete mode 100644 distributed/tests/test_channels.py delete mode 100644 docs/source/channels.rst diff --git a/distributed/channels.py b/distributed/channels.py deleted file mode 100644 index 0ed6d395150..00000000000 --- a/distributed/channels.py +++ /dev/null @@ -1,297 +0,0 @@ -from __future__ import print_function, division, absolute_import - -from collections import deque -import logging -from time import sleep -import threading -import warnings - -from .client import Future -from .core import CommClosedError -from .utils import tokey, log_errors - -logger = logging.getLogger(__name__) - - -class ChannelScheduler(object): - """ A plugin for the scheduler to manage channels - - This adds the following routes to the scheduler - - * channel-subscribe - * channel-unsubsribe - * channel-append - """ - def __init__(self, scheduler): - self.scheduler = scheduler - self.deques = dict() - self.counts = dict() - self.clients = dict() - self.stopped = dict() - - handlers = {'channel-subscribe': self.subscribe, - 'channel-unsubscribe': self.unsubscribe, - 'channel-append': self.append, - 'channel-stop': self.stop} - - self.scheduler.client_handlers.update(handlers) - self.scheduler.extensions['channels'] = self - - def subscribe(self, channel=None, client=None, maxlen=None): - logger.info("Add new client to channel, %s, %s", client, channel) - 
if channel not in self.deques: - logger.info("Add new channel %s", channel) - self.deques[channel] = deque(maxlen=maxlen) - self.counts[channel] = 0 - self.clients[channel] = set() - self.stopped[channel] = False - self.clients[channel].add(client) - - comm = self.scheduler.comms[client] - for type, value in self.deques[channel]: - comm.send({'op': 'channel-append', - 'type': type, - 'value': value, - 'channel': channel}) - - if self.stopped[channel]: - comm.send({'op': 'channel-stop', - 'channel': channel}) - - def unsubscribe(self, channel=None, client=None): - logger.info("Remove client from channel, %s, %s", client, channel) - self.clients[channel].remove(client) - if not self.clients[channel]: - del self.deques[channel] - del self.counts[channel] - del self.clients[channel] - del self.stopped[channel] - - def append(self, channel=None, type=None, value=None, client=None): - if self.stopped[channel]: - return - - if len(self.deques[channel]) == self.deques[channel].maxlen: - # TODO: future might still be in deque - typ, val = self.deques[channel].popleft() - if typ == 'Future': - self.scheduler.client_releases_keys(keys=[val], - client='streaming-%s' % channel) - - self.deques[channel].append((type, value)) - self.counts[channel] += 1 - self.report(channel, type, value) - - client = 'streaming-%s' % channel - if type == 'Future': - self.scheduler.client_desires_keys(keys=[value], client=client) - - def stop(self, channel=None, client=None): - self.stopped[channel] = True - logger.info("Stop channel %s", channel) - for client in list(self.clients[channel]): - try: - comm = self.scheduler.comms[client] - comm.send({'op': 'channel-stop', - 'channel': channel}) - except (KeyError, CommClosedError): - self.unsubscribe(channel, client) - - def report(self, channel, type, value): - for client in list(self.clients[channel]): - try: - comm = self.scheduler.comms[client] - comm.send({'op': 'channel-append', - 'type': type, - 'value': value, - 'channel': channel}) - 
except (KeyError, CommClosedError): - self.unsubscribe(channel, client) - - -class ChannelClient(object): - def __init__(self, client): - self.client = client - self.channels = dict() - self.client.extensions['channels'] = self - - handlers = {'channel-append': self.receive_key, - 'channel-stop': self.receive_stop} - - self.client._handlers.update(handlers) - - self.client.channel = self._create_channel # monkey patch - - def _create_channel(self, channel, maxlen=None): - if channel not in self.channels: - c = Channel(self.client, channel, maxlen=maxlen) - self.channels[channel] = c - return c - else: - return self.channels[channel] - - def receive_key(self, channel=None, type=None, value=None): - if channel not in self.channels: - self._create_channel(channel) - self.channels[channel]._receive_update(type, value) - - def receive_stop(self, channel=None): - self.channels[channel]._receive_stop() - - -class Channel(object): - """ - A changing stream of futures or data shared between clients - - Several clients connected to the same scheduler can communicate a sequence - of data or futures between each other through shared *channels*. All - clients can append to the channel at any time. All clients will be updated - when a channel updates. The central scheduler maintains consistency and - ordering of events. - - Channels can contain Future objects or generic data. All data should be - small and msgpack encodable (strings, numbers, lists, dicts.) Channels - should not be used to send large datasets directly. Instead scatter the - data into a Future and send that instead. 
- - Examples - -------- - - Create channels from your Client: - - >>> client = Client('scheduler-address:8786') # doctest: +SKIP - >>> chan = client.channel('my-channel') # doctest: +SKIP - - Append futures onto a channel - - >>> future = client.submit(add, 1, 2) # doctest: +SKIP - >>> chan.append(future) # doctest: +SKIP - - A channel maintains a collection of current futures added by both your - client, and others. - - >>> chan.data # doctest: +SKIP - deque([, - ]) - - You can iterate over a channel to get back futures. - - >>> for future in chan: # doctest: +SKIP - ... pass - - You can send small and simple data as well - - >>> chan.append({'score': 123}) # doctest: +SKIP - - To publish large amounts of data, scatter the data into a future - - >>> [future] = client.scatter([large_numpy_array]) # doctest: +SKIP - >>> chan.append(future) # doctest: +SKIP - """ - def __init__(self, client, name, maxlen=None): - self.client = client - self.name = name - self.data = deque(maxlen=maxlen) - self.stopped = False - self.count = 0 - self._pending = dict() - self._lock = threading.Lock() - self._thread_condition = threading.Condition() - - self.client._send_to_scheduler({'op': 'channel-subscribe', - 'channel': name, - 'maxlen': maxlen, - 'client': self.client.id}) - - @property - def futures(self): - warnings.warn("The .futures attribute has moved to .data") - return self.data - - def append(self, value): - """ Append a future onto the channel """ - if self.stopped: - raise StopIteration() - if isinstance(value, Future): - msg = {'op': 'channel-append', - 'channel': self.name, - 'type': 'Future', - 'value': tokey(value.key)} - self._pending[value.key] = value # hold on to reference until ack - else: - msg = {'op': 'channel-append', - 'channel': self.name, - 'type': 'value', - 'value': value} - - self.client._send_to_scheduler(msg) - - def stop(self): - if self.stopped: - return - self.client._send_to_scheduler({'op': 'channel-stop', - 'channel': self.name}) - - def 
_receive_update(self, type=None, value=None): - with self._lock: - self.count += 1 - if type == 'Future': - self.data.append(Future(value, self.client, inform=True)) - else: - self.data.append(value) - if type == 'Future': - if value in self._pending: - del self._pending[value] - - with self._thread_condition: - self._thread_condition.notify_all() - - def _receive_stop(self): - logger.info("Channel stopped: %s", self.name) - self.stopped = True - with self._thread_condition: - self._thread_condition.notify_all() - - def flush(self): - """ - Wait for acknowledgement from the scheduler on any pending futures - """ - while self._pending: - sleep(0.01) - - def __del__(self): - if not self.client.scheduler_comm.comm: - self.client._send_to_scheduler({'op': 'channel-unsubscribe', - 'channel': self.name, - 'client': self.client.id}) - - def __iter__(self): - with log_errors(): - with self._lock: - last = self.count - L = list(self.data) - for future in L: - yield future - - while True: - while self.count == last: - if self.stopped: - return - self._thread_condition.acquire() - self._thread_condition.wait() - self._thread_condition.release() - - with self._lock: - n = min(self.count - last, len(self.data)) - L = [self.data[i] for i in range(-n, 0)] - last = self.count - for f in L: - yield f - - def __len__(self): - return len(self.data) - - def __str__(self): - return "" % (self.name, len(self.data)) - - __repr__ = __str__ diff --git a/distributed/client.py b/distributed/client.py index dde5c741893..4488c834c14 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -488,8 +488,6 @@ def __init__(self, address=None, loop=None, timeout=5, self.start(timeout=timeout, asynchronous=asynchronous) - from distributed.channels import ChannelClient - ChannelClient(self) # registers itself on construction from distributed.recreate_exceptions import ReplayExceptionClient ReplayExceptionClient(self) @@ -920,6 +918,12 @@ def get_executor(self, **kwargs): """ return 
ClientExecutor(self, **kwargs) + def channel(self, *args, **kwargs): + """ Deprecated: see dask.distributed.Queue instead """ + msg = ("Channels have been removed. Consider using Queues instead. " + "http://distributed.readthedocs.io/en/latest/api.html#distributed.Queue") + raise NotImplementedError(msg) + def submit(self, func, *args, **kwargs): """ Submit a function application to the scheduler diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 79b8d54a666..1d0fba0fe12 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -38,7 +38,6 @@ from .utils_comm import (scatter_to_workers, gather_from_workers) from .versions import get_versions -from .channels import ChannelScheduler from .publish import PublishExtension from .queues import QueueExtension from .recreate_exceptions import ReplayExceptionScheduler @@ -191,7 +190,7 @@ class Scheduler(ServerNode): def __init__(self, center=None, loop=None, delete_interval=500, synchronize_worker_interval=60000, services=None, allowed_failures=ALLOWED_FAILURES, - extensions=[ChannelScheduler, PublishExtension, WorkStealing, + extensions=[PublishExtension, WorkStealing, ReplayExceptionScheduler, QueueExtension, VariableExtension], validate=False, scheduler_file=None, security=None, diff --git a/distributed/tests/py3_test_asyncio.py b/distributed/tests/py3_test_asyncio.py index a99e7748ac7..438c6ee3153 100644 --- a/distributed/tests/py3_test_asyncio.py +++ b/distributed/tests/py3_test_asyncio.py @@ -179,46 +179,6 @@ async def test_asyncio_exceptions(): assert result == 10 / 2 -@coro_test -async def test_asyncio_channels(): - async with AioClient(processes=False) as c: - x = c.channel('x') - y = c.channel('y') - - assert len(x) == 0 - - while set(c.extensions['channels'].channels) != {'x', 'y'}: - await asyncio.sleep(0.01) - - xx = c.channel('x') - yy = c.channel('y') - - assert len(x) == 0 - - await asyncio.sleep(0.1) - assert set(c.extensions['channels'].channels) == {'x', 'y'} - - future = 
c.submit(inc, 1) - - x.append(future) - - while not x.data: - await asyncio.sleep(0.01) - - assert len(x) == 1 - - assert xx.data[0].key == future.key - - xxx = c.channel('x') - while not xxx.data: - await asyncio.sleep(0.01) - - assert xxx.data[0].key == future.key - - assert 'x' in repr(x) - assert '1' in repr(x) - - @coro_test async def test_asyncio_exception_on_exception(): async with AioClient(processes=False) as c: diff --git a/distributed/tests/test_channels.py b/distributed/tests/test_channels.py deleted file mode 100644 index 20bf8d1427c..00000000000 --- a/distributed/tests/test_channels.py +++ /dev/null @@ -1,231 +0,0 @@ -from __future__ import print_function, division, absolute_import - -from operator import add -from time import sleep - -import pytest -from toolz import take -from tornado import gen - -from distributed import Client -from distributed import worker_client -from distributed.metrics import time -from distributed.utils_test import gen_cluster, inc, loop, cluster, slowinc - - -@gen_cluster(client=True) -def test_channel(c, s, a, b): - x = c.channel('x') - y = c.channel('y') - - assert len(x) == 0 - - while set(c.extensions['channels'].channels) != {'x', 'y'}: - yield gen.sleep(0.01) - - xx = c.channel('x') - yy = c.channel('y') - - assert len(x) == 0 - - yield gen.sleep(0.1) - assert set(c.extensions['channels'].channels) == {'x', 'y'} - - future = c.submit(inc, 1) - - x.append(future) - - while not x.data: - yield gen.sleep(0.01) - - assert len(x) == 1 - - assert xx.data[0].key == future.key - - xxx = c.channel('x') - while not xxx.data: - yield gen.sleep(0.01) - - assert xxx.data[0].key == future.key - - assert 'x' in repr(x) - assert '1' in repr(x) - - -def test_worker_client(loop): - def produce(n): - with worker_client() as c: - x = c.channel('x') - for i in range(n): - future = c.submit(slowinc, i, delay=0.01, key='f-%d' % i) - x.append(future) - - x.flush() - - def consume(): - with worker_client() as c: - x = c.channel('x') - y = 
c.channel('y') - last = 0 - for i, future in enumerate(x): - last = c.submit(add, future, last, key='add-' + future.key) - y.append(last) - - y.flush() - - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: - x = c.channel('x') - y = c.channel('y') - - producers = (c.submit(produce, 5), c.submit(produce, 10)) - consumer = c.submit(consume) - - results = [] - for i, future in enumerate(take(15, y)): - result = future.result() - results.append(result) - - assert len(results) == 15 - assert all(0 < r < 100 for r in results) - - -@gen_cluster(client=True) -def test_channel_scheduler(c, s, a, b): - chan = c.channel('chan', maxlen=5) - - x = c.submit(inc, 1) - key = x.key - chan.append(x) - del x - - while not len(chan): - yield gen.sleep(0.01) - - assert 'streaming-chan' in s.who_wants[key] - assert s.wants_what['streaming-chan'] == {key} - - while len(s.who_wants[key]) < 2: - yield gen.sleep(0.01) - - assert s.wants_what[c.id] == {key} - - for i in range(10): - chan.append(c.submit(inc, i)) - - start = time() - while True: - if len(chan) == len(s.task_state) == 5: - break - else: - assert time() < start + 2 - yield gen.sleep(0.01) - - results = yield c.gather(list(chan.data)) - assert results == [6, 7, 8, 9, 10] - - -@gen_cluster(client=True) -def test_multiple_maxlen(c, s, a, b): - c2 = yield Client((s.ip, s.port), asynchronous=True) - - x = c.channel('x', maxlen=10) - assert x.data.maxlen == 10 - x2 = c2.channel('x', maxlen=20) - assert x2.data.maxlen == 20 - - for i in range(10): - x.append(c.submit(inc, i)) - - while len(s.wants_what[c2.id]) < 10: - yield gen.sleep(0.01) - - for i in range(10, 20): - x.append(c.submit(inc, i)) - - while len(x2) < 20: - yield gen.sleep(0.01) - - yield gen.sleep(0.1) - - assert len(x2) == 20 # They stay this long after a delay - assert len(s.task_state) == 20 - - yield c2.shutdown() - - -def test_stop(loop): - def produce(n): - with worker_client() as c: - x = c.channel('x') - for i in range(n): - future = 
c.submit(slowinc, i, delay=0.01, key='f-%d' % i) - x.append(future) - - x.stop() - x.flush() - - with cluster() as (s, [a, b]): - with Client(s['address']) as c: - x = c.channel('x') - - producer = c.submit(produce, 5) - - futures = list(x) - assert len(futures) == 5 - - with pytest.raises(StopIteration): - x.append(c.submit(inc, 1)) - - with Client(s['address']) as c2: - xx = c2.channel('x') - futures = list(xx) - assert len(futures) == 5 - - -@gen_cluster(client=True) -def test_values(c, s, a, b): - c2 = yield Client((s.ip, s.port), asynchronous=True) - - x = c.channel('x') - x2 = c2.channel('x') - - data = [123, 'Hello!', {'x': [1, 2, 3]}] - for item in data: - x.append(item) - - while len(x2.data) < 3: - yield gen.sleep(0.01) - - assert list(x2.data) == data - - yield c2.shutdown() - - -def test_channel_gets_updates_immediately(loop): - with cluster() as (s, [a, b]): - with Client(s['address']) as c: - x = c.channel('x') - future = c.submit(inc, 1) - x.append(future) - x.flush() - - with Client(s['address']) as c: - x = c.channel('x') - future = next(iter(x)) - assert future.result() == 2 - - -def test_channel_gets_updates_immediately_2(loop): - with cluster() as (s, [a, b]): - with Client(s['address']) as c: - x = c.channel('x') - - with Client(s['address']) as c2: - x2 = c.channel('x') - future = c2.submit(inc, 1) - x2.append(future) - x2.flush() - - future = next(iter(x)) - assert future.result() == 2 diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index 7ea057a6337..fe0d017f383 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -12,7 +12,7 @@ from toolz import take from tornado import gen -from distributed import Nanny, worker_client +from distributed import Nanny, worker_client, Queue from distributed.client import _wait from distributed.metrics import time from distributed.utils_test import (gen_cluster, tls_only_security, @@ -26,22 +26,21 @@ def 
gen_tls_cluster(**kwargs): @gen_tls_cluster(client=True) -def test_channel(c, s, a, b): +def test_Queue(c, s, a, b): assert s.address.startswith('tls://') - x = c.channel('x') - y = c.channel('y') + x = Queue('x') + y = Queue('y') - assert len(x) == 0 + size = yield x.qsize() + assert size == 0 future = c.submit(inc, 1) - x.append(future) + yield x.put(future) - while not x.data: - yield gen.sleep(0.01) - - assert len(x) == 1 + future2 = yield x.get() + assert future.key == future2.key @gen_tls_cluster(client=True, timeout=None) diff --git a/docs/source/channels.rst b/docs/source/channels.rst deleted file mode 100644 index ad1ff1bb42e..00000000000 --- a/docs/source/channels.rst +++ /dev/null @@ -1,149 +0,0 @@ -Shared Futures With Channels -============================ - -A channel is a changing stream of data or futures shared between any number of -clients connected to the same scheduler. - -Any client can push msgpack-encodable data (numbers, strings, lists, dicts) or -Future objects into a channel. All other clients listening to the channel will -receive that data or future as a linear sequence. Communication happens -through the scheduler, which linearizes everything. Channels should only be -used to move small amounts of administrative data. For larger data, create -futures and send the futures instead. - - -Examples --------- - -Basic Usage -~~~~~~~~~~~ - -Create channels from your Client: - -.. code-block:: python - - >>> client = Client('scheduler-address:8786') - >>> chan = client.channel('my-channel') - -Append futures or data onto a channel - -.. code-block:: python - - >>> future = client.submit(add, 1, 2) - >>> chan.append(future) - >>> chan.append({'x': 123}) - -A channel maintains a collection of current futures added by both your -client, and others. - -.. code-block:: python - - >>> chan.data - deque([, - {'x': 123}]) - -If you wish to persist the current status of tasks outside of the distributed -cluster (e.g. 
to take a snapshot in case of shutdown) you can copy a channel full -of data as it is a python deque_. - -.. _deque: https://docs.python.org/3.5/library/collections.html#collections.deque` - -.. code-block:: python - - channelcopy = list(chan.data) - -You can iterate over a channel to get back data. - -.. code-block:: python - - >>> anotherclient = Client('scheduler-address:8786') - >>> chan = anotherclient.channel('my-channel') - >>> for element in chan: - ... pass - -When done writing, call flush to wait until your appends have been -fully registered with the scheduler. - -.. code-block:: python - - >>> client = Client('scheduler-address:8786') - >>> chan = client.channel('my-channel') - >>> future2 = client.submit(time.sleep, 2) - >>> chan.append(future2) - >>> chan.flush() - - -Example with worker_client -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Using channels with `worker_client`_ allows for a more decoupled version -of what is possible with :doc:`Data Streams with Queues` -in that independent worker clients can build up a set of results -which can be read later by a different client. -This opens up Dask/Distributed to being integrated in a wider application -environment similar to other python task queues such as Celery_. - -.. _worker_client: http://distributed.readthedocs.io/en/latest/task-launch.html#submit-tasks-from-worker -.. _Celery: http://www.celeryproject.org/ - -.. 
code-block:: python - - import random, time, operator - from distributed import Client, worker_client - from time import sleep - - def emit(name): - with worker_client() as c: - chan = c.channel(name) - while True: - future = c.submit(random.random, pure=False) - chan.append(future) - sleep(1) - - def combine(): - with worker_client() as c: - a_chan = c.channel('a') - b_chan = c.channel('b') - out_chan = c.channel('adds') - for a, b in zip(a_chan, b_chan): - future = c.submit(operator.add, a, b) - out_chan.append(future) - - client = Client() - - emitters = (client.submit(emit, 'a'), client.submit(emit, 'b')) - combiner = client.submit(combine) - chan = client.channel('adds') - - - for future in chan: - print(future.result()) - ...: - 1.782009416831722 - ... - -All iterations on a channel by different clients can be stopped using the ``stop`` method - -.. code-block:: python - - chan.stop() - - -Additional Applications ------------------------ - -Channels can serve as a coordination point or semaphore. They can signal -stopping criteria for iterative processes. - -Short lived clients, such as occur when firing off controlling tasks from a web -application, AWS Lambda, or other fire and forget script, often need a place to -store their futures so that in-flight work doesn't get garbage collected. -Because channels act as clients for the purpose of garbage collection (all -futures within a Channel are considered desired) they can serve as this -repository after short-lived clients die off. - -Worker clients can communicate large amounts of data to each other using -channels by first scattering local data to themselves, creating futures, and -then pushing those futures down a shared channel. When subscribers to the -channel gather these futures they will engage the normal high-bandwidth -inter-worker communication mechanism. 
diff --git a/docs/source/index.rst b/docs/source/index.rst index ffb6b520c1d..55845e2aacf 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -104,7 +104,6 @@ Contents adaptive asynchronous - channels configuration ec2 local-cluster From 4967568e3c216b9d9d109c49b4a6a6b2c3465950 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 5 Jul 2017 14:28:27 -0400 Subject: [PATCH 3/5] Add close method to client --- distributed/client.py | 17 +++++++++++++++++ distributed/tests/test_client.py | 14 ++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/distributed/client.py b/distributed/client.py index 4488c834c14..4dc97d2890b 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -835,6 +835,23 @@ def _handle_error(self, exception=None): logger.warning("Scheduler exception:") logger.exception(exception) + @gen.coroutine + def _close(self): + with log_errors(): + if self.status == 'closed': + raise gen.Return() + if self.status == 'running': + self._send_to_scheduler({'op': 'close-stream'}) + with ignoring(AttributeError): + yield self.scheduler_comm.close() + with ignoring(AttributeError): + self.scheduler.close_rpc() + self.status = 'closed' + + def close(self): + """ Close this client and its connection to the scheduler """ + return self.sync(self._close) + @gen.coroutine def _shutdown(self, fast=False): """ Send shutdown signal and wait until scheduler completes """ diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c576ba8563d..6da71558ae7 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4226,6 +4226,20 @@ def test_quiet_client_shutdown(loop): assert not out +@gen_cluster() +def test_close(s, a, b): + c = yield Client(s.address, asynchronous=True) + future = c.submit(inc, 1) + yield wait(future) + assert c.id in s.wants_what + yield c.close() + + start = time() + while c.id in s.wants_what or s.task_state: + yield gen.sleep(0.01) + assert time() < start + 5 + 
+ def test_threadsafe(loop): with cluster() as (s, [a, b]): with Client(s['address'], loop=loop) as c: From f389ac68993f96cb49fc1f7d6bf807f8c2763b47 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 5 Jul 2017 14:29:47 -0400 Subject: [PATCH 4/5] Add metadata to dataframe tests This caused unpleasant warnings in tests otherwise --- distributed/tests/test_collections.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index aaf8decb2f8..71d0fc0894c 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -163,20 +163,20 @@ def test_dataframe_groupby_tasks(loop): with dask.set_options(get=c.get): for ind in [lambda x: 'A', lambda x: x.A]: a = df.groupby(ind(df)).apply(len) - b = ddf.groupby(ind(ddf)).apply(len) + b = ddf.groupby(ind(ddf)).apply(len, meta=int) assert_equal(a, b.compute(get=dask.get).sort_index()) assert not any('partd' in k[0] for k in b.dask) a = df.groupby(ind(df)).B.apply(len) - b = ddf.groupby(ind(ddf)).B.apply(len) + b = ddf.groupby(ind(ddf)).B.apply(len, meta=('B', int)) assert_equal(a, b.compute(get=dask.get).sort_index()) assert not any('partd' in k[0] for k in b.dask) with pytest.raises(NotImplementedError): - ddf.groupby(ddf[['A', 'B']]).apply(len) + ddf.groupby(ddf[['A', 'B']]).apply(len, meta=int) a = df.groupby(['A', 'B']).apply(len) - b = ddf.groupby(['A', 'B']).apply(len) + b = ddf.groupby(['A', 'B']).apply(len, meta=int) assert_equal(a, b.compute(get=dask.get).sort_index()) From fe2adcb6f7871a8b92c720852bd1289b5c4884c8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 5 Jul 2017 14:37:49 -0400 Subject: [PATCH 5/5] Add get_client and secede functions These separate and optionally replace worker_client. Additionally this required additional logic for handling default clients and workers, serialization of futures, queues, and variables. 
This created functionality and tests that also triggered changes to how the scheduler tracks long-running tasks. --- distributed/__init__.py | 2 +- distributed/client.py | 133 ++++++++++---- distributed/queues.py | 7 +- distributed/scheduler.py | 50 +++++- distributed/stealing.py | 5 - distributed/tests/test_client.py | 175 +++++++++++++++++-- distributed/tests/test_steal.py | 2 +- distributed/tests/test_threadpoolexecutor.py | 2 - distributed/tests/test_worker.py | 84 ++++++++- distributed/tests/test_worker_client.py | 35 +++- distributed/variable.py | 7 +- distributed/worker.py | 155 +++++++++++++++- distributed/worker_client.py | 135 +------------- docs/source/api.rst | 4 + 14 files changed, 584 insertions(+), 212 deletions(-) diff --git a/distributed/__init__.py b/distributed/__init__.py index a9fc668e106..52cb8abc210 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -12,7 +12,7 @@ from .scheduler import Scheduler from .utils import sync from .variable import Variable -from .worker import Worker, get_worker +from .worker import Worker, get_worker, get_client, secede from .worker_client import local_client, worker_client from ._version import get_versions diff --git a/distributed/client.py b/distributed/client.py index 4dc97d2890b..ee2268a1611 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -45,7 +45,8 @@ from .protocol import to_serialize from .protocol.pickle import dumps, loads from .security import Security -from .worker import dumps_task +from .sizeof import sizeof +from .worker import dumps_task, thread_state, get_client from .utils import (All, sync, funcname, ignoring, queue_to_iterator, tokey, log_errors, str_graph, key_split, format_bytes) from .versions import get_versions @@ -269,10 +270,11 @@ def release(self): self.client._dec_ref(tokey(self.key)) def __getstate__(self): - return self.key + return (self.key, self.client.scheduler.address) - def __setstate__(self, key): - c = default_client() + def 
__setstate__(self, state): + key, address = state + c = get_client(address) Future.__init__(self, key, c) c._send_to_scheduler({'op': 'update-graph', 'tasks': {}, 'keys': [tokey(self.key)], 'client': c.id}) @@ -430,7 +432,7 @@ def __init__(self, address=None, loop=None, timeout=5, self.futures = dict() self.refcount = defaultdict(lambda: 0) self.coroutines = [] - self.id = type(self).__name__ + '-' + str(uuid.uuid1()) + self.id = type(self).__name__ + '-' + str(uuid.uuid1(clock_seq=os.getpid())) self.generation = 0 self.status = None self._pending_msg_buffer = [] @@ -441,7 +443,7 @@ def __init__(self, address=None, loop=None, timeout=5, assert isinstance(self.security, Security) self.connection_args = self.security.get_connection_args('client') self._connecting_to_scheduler = False - self.asynchronous = asynchronous + self._asynchronous = asynchronous self._loop_thread = None self.scheduler = None self._lock = threading.Lock() @@ -486,11 +488,34 @@ def __init__(self, address=None, loop=None, timeout=5, super(Client, self).__init__(connection_args=self.connection_args, io_loop=self.loop) - self.start(timeout=timeout, asynchronous=asynchronous) + self.start(timeout=timeout) from distributed.recreate_exceptions import ReplayExceptionClient ReplayExceptionClient(self) + @property + def asynchronous(self): + """ Are we running in the event loop? + + This is true if the user signaled that we might be when creating the + client as in the following:: + + client = Client(asynchronous=True) + + However, we override this expectation if we can definitively tell that + we are running from a thread that is not the event loop. This is + common when calling get_client() from within a worker task. Even + though the client was originally created in asynchronous mode we may + find ourselves in contexts when it is better to operate synchronously. 
+ """ + result = self._asynchronous + try: + if get_thread_identity() != self.loop._thread_ident: + result = False + except AttributeError: # AsyncIOLoop doesn't have _thread_ident + pass + return result + def sync(self, func, *args, **kwargs): if self.asynchronous: callback_timeout = kwargs.pop('callback_timeout', None) @@ -563,11 +588,12 @@ def _repr_html_(self): else: return text - def start(self, asynchronous=None, **kwargs): + def start(self, **kwargs): """ Start scheduler running in separate thread """ if self._loop_thread is not None: return - if not asynchronous and not self.loop._running: + + if not self.asynchronous and not self.loop._running: self._loop_thread = threading.Thread(target=self.loop.start, name="Client loop") self._loop_thread.daemon = True @@ -576,14 +602,16 @@ def start(self, asynchronous=None, **kwargs): self._should_close_loop = True while not self.loop._running: sleep(0.001) + pc = PeriodicCallback(lambda: None, 1000, io_loop=self.loop) self.loop.add_callback(pc.start) _set_global_client(self) - if asynchronous: + self.status = 'connecting' + + if self.asynchronous: self._started = self._start(**kwargs) else: sync(self.loop, self._start, **kwargs) - self.status = 'running' def __await__(self): return self._started.__await__() @@ -1105,7 +1133,7 @@ def map(self, func, *iterables, **kwargs): keys = [key + '-' + tokenize(func, kwargs, *args) for args in zip(*iterables)] else: - uid = str(uuid.uuid4()) + uid = str(uuid.uuid1()) keys = [key + '-' + uid + '-' + str(i) for i in range(min(map(len, iterables)))] if iterables else [] @@ -1151,7 +1179,7 @@ def map(self, func, *iterables, **kwargs): return [futures[tokey(k)] for k in keys] @gen.coroutine - def _gather(self, futures, errors='raise', direct=False): + def _gather(self, futures, errors='raise', direct=False, local_worker=None): futures2, keys = unpack_remotedata(futures, byte_keys=True) keys = [tokey(key) for key in keys] bad_data = dict() @@ -1196,20 +1224,29 @@ def wait(k): 
bad_data[key] = None else: raise ValueError("Bad value, `errors=%s`" % errors) + keys = [k for k in keys if k not in bad_keys] - if direct: + data = {} + + if local_worker: # look inside local worker + data.update({k: local_worker.data[k] + for k in keys + if k in local_worker.data}) + keys = [k for k in keys if k not in data] + + if direct or local_worker: # gather directly from workers who_has = yield self.scheduler.who_has(keys=keys) - data, missing_keys, missing_workers = yield gather_from_workers( + data2, missing_keys, missing_workers = yield gather_from_workers( who_has, rpc=self.rpc, close=False) - response = {'status': 'OK', 'data': data} + response = {'status': 'OK', 'data': data2} if missing_keys: - keys2 = [key for key in keys if key not in data] + keys2 = [key for key in keys if key not in data2] response = yield self.scheduler.gather(keys=keys2) if response['status'] == 'OK': - response['data'].update(data) + response['data'].update(data2) - else: + else: # ask scheduler to gather data for us response = yield self.scheduler.gather(keys=keys) if response['status'] == 'error': @@ -1225,7 +1262,7 @@ def wait(k): if bad_data and errors == 'skip' and isinstance(futures2, list): futures2 = [f for f in futures2 if f not in bad_data] - data = response['data'] + data.update(response['data']) result = pack_data(futures2, merge(data, bad_data)) raise gen.Return(result) @@ -1296,11 +1333,16 @@ def gather(self, futures, errors='raise', maxsize=0, direct=False): return (self.gather(f, errors=errors, direct=direct) for f in futures) else: + if hasattr(thread_state, 'execution_state'): # within worker task + local_worker = thread_state.execution_state['worker'] + else: + local_worker = None return self.sync(self._gather, futures, errors=errors, - direct=direct) + direct=direct, local_worker=local_worker) @gen.coroutine - def _scatter(self, data, workers=None, broadcast=False, direct=False): + def _scatter(self, data, workers=None, broadcast=False, direct=False, + 
local_worker=None): if isinstance(workers, six.string_types): workers = [workers] if isinstance(data, dict) and not all(isinstance(k, (bytes, unicode)) @@ -1326,24 +1368,34 @@ def _scatter(self, data, workers=None, broadcast=False, direct=False): assert isinstance(data, dict) - data2 = valmap(to_serialize, data) types = valmap(type, data) - if direct: - ncores = yield self.scheduler.ncores(workers=workers) - if not ncores: - raise ValueError("No valid workers") - _, who_has, nbytes = yield scatter_to_workers(ncores, data2, - report=False, - rpc=self.rpc) + if local_worker: # running within task + local_worker.update_data(data=data, report=False) + + yield self.scheduler.update_data( + who_has={key: [local_worker.address] for key in data}, + nbytes=valmap(sizeof, data), + client=self.id) - yield self.scheduler.update_data(who_has=who_has, nbytes=nbytes) else: - yield self.scheduler.scatter(data=data2, workers=workers, - client=self.id, - broadcast=broadcast) + data2 = valmap(to_serialize, data) + if direct: + ncores = yield self.scheduler.ncores(workers=workers) + if not ncores: + raise ValueError("No valid workers") + + _, who_has, nbytes = yield scatter_to_workers(ncores, data2, + report=False, + rpc=self.rpc) + + yield self.scheduler.update_data(who_has=who_has, nbytes=nbytes) + else: + yield self.scheduler.scatter(data=data2, workers=workers, + client=self.id, + broadcast=broadcast) - out = {k: self._Future(k, self) for k in data2} + out = {k: self._Future(k, self) for k in data} for key, typ in types.items(): self.futures[key].finish(type=typ) @@ -1459,8 +1511,13 @@ def scatter(self, data, workers=None, broadcast=False, direct=False, maxsize=0): else: return queue_to_iterator(qout) else: + if hasattr(thread_state, 'execution_state'): # inside worker task + local_worker = thread_state.execution_state['worker'] + else: + local_worker = None return self.sync(self._scatter, data, workers=workers, - broadcast=broadcast, direct=direct) + broadcast=broadcast, 
direct=direct, + local_worker=local_worker) @gen.coroutine def _cancel(self, futures): @@ -1764,8 +1821,8 @@ def _graph_to_futures(self, dsk, keys, restrictions=None, 'restrictions': restrictions or {}, 'loose_restrictions': loose_restrictions, 'priority': priority, - 'resources': resources}) - + 'resources': resources, + 'submitting_task': getattr(thread_state, 'key', None)}) return futures def get(self, dsk, keys, restrictions=None, loose_restrictions=None, diff --git a/distributed/queues.py b/distributed/queues.py index d633181bda0..bb2b617571d 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -14,6 +14,7 @@ from .client import Future, _get_global_client, Client from .utils import tokey, sync +from .worker import get_client logger = logging.getLogger(__name__) @@ -240,7 +241,9 @@ def __getstate__(self): def __setstate__(self, state): name, address = state - client = _get_global_client() - if client is None or client.scheduler.address != address: + try: + client = get_client(address) + assert client.address == address + except (AttributeError, AssertionError): client = Client(address, set_as_default=False) self.__init__(name=name, client=client) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 1d0fba0fe12..1a10358a29f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -296,7 +296,8 @@ def __init__(self, center=None, loop=None, 'task-erred': self.handle_task_erred, 'release': self.handle_missing_data, 'release-worker-data': self.release_worker_data, - 'add-keys': self.add_keys} + 'add-keys': self.add_keys, + 'long-running': self.handle_long_running} self.client_handlers = {'update-graph': self.update_graph, 'client-desires-keys': self.client_desires_keys, @@ -642,7 +643,8 @@ def add_worker(self, comm=None, address=None, keys=(), ncores=None, def update_graph(self, client=None, tasks=None, keys=None, dependencies=None, restrictions=None, priority=None, - loose_restrictions=None, resources=None): + 
loose_restrictions=None, resources=None, + submitting_task=None): """ Add new computations to the internal dask graph @@ -699,10 +701,17 @@ def update_graph(self, client=None, tasks=None, keys=None, recommendations = OrderedDict() new_priority = priority or order(tasks) # TODO: define order wrt old graph - self.generation += 1 # older graph generations take precedence + if submitting_task: # sub-tasks get better priority than parent tasks + try: + generation = self.priority[submitting_task][0] - 0.01 + except KeyError: # super-task already cleaned up + generation = self.generation + else: + self.generation += 1 # older graph generations take precedence + generation = self.generation for key in set(new_priority) & touched: if key not in self.priority: - self.priority[key] = (self.generation, new_priority[key]) # prefer old + self.priority[key] = (generation, new_priority[key]) # prefer old if restrictions: # *restrictions* is a dict keying task ids to lists of @@ -1293,6 +1302,37 @@ def release_worker_data(self, stream=None, keys=None, worker=None): if recommendations: self.transitions(recommendations) + def handle_long_running(self, key=None, worker=None, compute_duration=None): + """ A task has seceded from the thread pool + + We stop the task from being stolen in the future, and change task + duration accounting as if the task has stopped. + """ + self.extensions['stealing'].remove_key_from_stealable(key) + + try: + actual_worker = self.rprocessing[key] + except KeyError: + logger.debug("Received long-running signal from duplicate task. 
" + "Ignoring.") + return + + if compute_duration: + ks = key_split(key) + old_duration = self.task_duration.get(ks, 0) + new_duration = compute_duration + if not old_duration: + avg_duration = new_duration + else: + avg_duration = (0.5 * old_duration + + 0.5 * new_duration) + + self.task_duration[ks] = avg_duration + + worker = self.rprocessing[key] + self.occupancy[actual_worker] -= self.processing[actual_worker][key] + self.processing[actual_worker][key] = 0 + @gen.coroutine def handle_worker(self, worker): """ @@ -2307,7 +2347,7 @@ def transition_processing_memory(self, key, nbytes=None, type=None, ############################# # Update Timing Information # ############################# - if compute_start: + if compute_start and self.processing[worker].get(key, True): # Update average task duration for worker info = self.worker_info[worker] ks = key_split(key) diff --git a/distributed/stealing.py b/distributed/stealing.py index 79f371d2aad..b97baf9d766 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -51,8 +51,6 @@ def __init__(self, scheduler): self.scheduler.events['stealing'] = deque(maxlen=100000) self.count = 0 - scheduler.worker_handlers['long-running'] = self.transition_long_running - @property def log(self): return self.scheduler.events['stealing'] @@ -80,9 +78,6 @@ def transition(self, key, start, finish, compute_start=None, if self.scheduler.task_state[k] == 'processing': self.put_key_in_stealable(k, split=ks) - def transition_long_running(self, key=None, worker=None): - self.remove_key_from_stealable(key) - def put_key_in_stealable(self, key, split=None): worker = self.scheduler.rprocessing[key] cost_multiplier, level = self.steal_time_ratio(key, split=split) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 6da71558ae7..461f7dc7dbe 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -11,6 +11,7 @@ import pickle import random import sys +import threading 
from threading import Thread, Semaphore from time import sleep import traceback @@ -27,7 +28,8 @@ import dask from dask import delayed from dask.context import _globals -from distributed import Worker, Nanny, recreate_exceptions, fire_and_forget +from distributed import (Worker, Nanny, recreate_exceptions, fire_and_forget, + get_client, secede, get_worker) from distributed.comm import CommClosedError from distributed.utils_comm import WrappedKey from distributed.client import (Client, Future, _wait, @@ -4253,9 +4255,9 @@ def f(_): total = c.submit(sum, list(d)) return total.result() - from multiprocessing.pool import ThreadPool - pool = ThreadPool(20) - results = pool.map(f, range(20)) + from concurrent.futures import ThreadPoolExecutor + e = ThreadPoolExecutor(20) + results = list(e.map(f, range(20))) assert results and all(results) @@ -4272,9 +4274,9 @@ def f(_): sleep(0.001) return total - from multiprocessing.pool import ThreadPool - pool = ThreadPool(30) - results = pool.map(f, range(30)) + from concurrent.futures import ThreadPoolExecutor + e = ThreadPoolExecutor(30) + results = list(e.map(f, range(30))) assert results and all(results) @@ -4292,9 +4294,9 @@ def f(_): sleep(0.001) return total - from multiprocessing.pool import ThreadPool - pool = ThreadPool(30) - results = pool.map(f, range(30)) + from concurrent.futures import ThreadPoolExecutor + e = ThreadPoolExecutor(30) + results = list(e.map(f, range(30))) assert results and all(results) @@ -4306,6 +4308,159 @@ def test_identity(c, s, a, b): assert s.id.lower().startswith('scheduler') +@gen_cluster(client=True, ncores=[('127.0.0.1', 4)] * 2) +def test_get_client(c, s, a, b): + assert get_client() is c + assert c.asynchronous + def f(x): + client = get_client() + future = client.submit(inc, x) + import distributed + assert not client.asynchronous + assert client is distributed.tmp_client + return future.result() + + import distributed + distributed.tmp_client = c + try: + futures = c.map(f, range(5)) + 
results = yield c.gather(futures) + assert results == list(map(inc, range(5))) + finally: + del distributed.tmp_client + + +@gen_cluster(client=True) +def test_serialize_collections(c, s, a, b): + da = pytest.importorskip('dask.array') + x = da.arange(10, chunks=(5,)).persist() + + def f(x): + assert isinstance(x, da.Array) + return x.sum().compute() + + future = c.submit(f, x) + result = yield future + assert result == sum(range(10)) + + +@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 1, timeout=100) +def test_secede_simple(c, s, a): + def f(): + client = get_client() + secede() + return client.submit(inc, 1).result() + + result = yield c.submit(f) + assert result == 2 + + +@slow +@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2, timeout=60) +def test_secede_balances(c, s, a, b): + def f(x): + client = get_client() + sleep(0.01) # do some work + secede() + futures = client.map(slowinc, range(10), pure=False, delay=0.01) + total = client.submit(sum, futures).result() + return total + + futures = c.map(f, range(100)) + start = time() + while not all(f.status == 'finished' for f in futures): + yield gen.sleep(0.01) + assert threading.active_count() < 50 + + # assert 0.005 < s.task_duration['f'] < 0.1 + assert len(a.log) < 2 * len(b.log) + assert len(b.log) < 2 * len(a.log) + + results = yield c.gather(futures) + assert results == [sum(map(inc, range(10)))] * 100 + + +@gen_cluster(client=True) +def test_sub_submit_priority(c, s, a, b): + def f(): + client = get_client() + client.submit(slowinc, 1, delay=0.2) + + future = c.submit(f) + yield gen.sleep(0.1) + if len(s.task_state) == 2: + f_key = [k for k in s.task_state if k.startswith('f')][0] + slowinc_key = [k for k in s.task_state if k.startswith('slowinc')][0] + assert s.priorities[f_key] > s.priorities[slowinc_key] # lower values schedule first + + +def test_get_client_sync(loop): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + results = c.run(lambda: 
get_worker().scheduler.address) + assert results == {w['address']: s['address'] for w in [a, b]} + + results = c.run(lambda: get_client().scheduler.address) + assert results == {w['address']: s['address'] for w in [a, b]} + + +@gen_cluster(client=True) +def test_serialize_collections_of_futures(c, s, a, b): + pd = pytest.importorskip('pandas') + dd = pytest.importorskip('dask.dataframe') + from dask.dataframe.utils import assert_eq + + df = pd.DataFrame({'x': [1, 2, 3]}) + ddf = dd.from_pandas(df, npartitions=2).persist() + future = yield c.scatter(ddf) + + ddf2 = yield future + df2 = yield c.compute(ddf2) + + assert_eq(df, df2) + + +def test_serialize_collections_of_futures_sync(loop): + pd = pytest.importorskip('pandas') + dd = pytest.importorskip('dask.dataframe') + from dask.dataframe.utils import assert_eq + + df = pd.DataFrame({'x': [1, 2, 3]}) + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + ddf = dd.from_pandas(df, npartitions=2).persist() + future = c.scatter(ddf) + + result = future.result() + assert_eq(result.compute(), df) + + assert future.type == dd.DataFrame + assert c.submit(lambda x, y: assert_eq(x.compute(), y), future, df).result() + + +def _dynamic_workload(x, delay=0.01): + if delay == 'random': + sleep(random.random() / 2) + else: + sleep(delay) + if x > 4: + return 4 + secede() + client = get_client() + futures = client.map(_dynamic_workload, [x + i + 1 for i in range(2)], + pure=False, delay=delay) + total = client.submit(sum, futures) + return total.result() + + +@pytest.mark.parametrize('delay', [0.02, slow('random')]) +def test_dynamic_workloads_sync(loop, delay): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + future = c.submit(_dynamic_workload, 0, delay=delay) + assert future.result(timeout=40) == 52 + + @gen_cluster(client=True) def test_bytes_keys(c, s, a, b): key = b'inc-123' diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 
a191cfb07a4..ef6f5c7a5f9 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -548,7 +548,7 @@ def long(delay): assert time() < start + 1 na = len(a.executing) - nb= len(b.executing) + nb = len(b.executing) incs = c.map(inc, range(100), workers=a.address, allow_other_workers=True) diff --git a/distributed/tests/test_threadpoolexecutor.py b/distributed/tests/test_threadpoolexecutor.py index fe8a0a8e2c3..3b46724d34f 100644 --- a/distributed/tests/test_threadpoolexecutor.py +++ b/distributed/tests/test_threadpoolexecutor.py @@ -17,8 +17,6 @@ def f(): assert e.submit(f).result() == 1 - assert len(e._threads) == 1 - list(e.map(sleep, [0.01] * 4)) assert len(threads | e._threads) == 3 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index d16ff229757..3403a6f13a8 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -18,7 +18,7 @@ from tornado.ioloop import TimeoutError import distributed -from distributed import Nanny +from distributed import Nanny, Client, get_client, wait, default_client from distributed.core import rpc, connect from distributed.client import _wait from distributed.scheduler import Scheduler @@ -29,8 +29,7 @@ from distributed.worker import Worker, error_message, logger, TOTAL_MEMORY from distributed.utils import ignoring, tmpfile from distributed.utils_test import (loop, inc, mul, gen_cluster, div, dec, - slow, slowinc, throws, gen_test, readone) - + slow, slowinc, throws, gen_test, readone, cluster) def test_worker_ncores(): @@ -806,3 +805,82 @@ def __sizeof__(self): @gen_cluster() def test_pid(s, a, b): assert s.worker_info[a.address]['pid'] == os.getpid() + + +@gen_cluster(client=True) +def test_get_client(c, s, a, b): + def f(x): + cc = get_client() + future = cc.submit(inc, x) + return future.result() + + assert default_client() is c + + future = c.submit(f, 10, workers=a.address) + result = yield future + assert result == 11 + + assert a._client + 
assert not b._client + + assert a._client is c + assert default_client() is c + + a_client = a._client + + for i in range(10): + yield wait(c.submit(f, i)) + + assert a._client is a_client + + +def test_get_client_sync(loop): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + def f(x): + cc = get_client() + future = cc.submit(inc, x) + return future.result() + + future = c.submit(f, 10) + assert future.result() == 11 + + +@gen_cluster(client=True) +def test_get_client_coroutine(c, s, a, b): + @gen.coroutine + def f(): + client = yield get_client() + future = client.submit(inc, 10) + result = yield future + raise gen.Return(result) + + results = yield c.run_coroutine(f) + assert results == {a.address: 11, + b.address: 11} + + +def test_get_client_coroutine_sync(loop): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + @gen.coroutine + def f(): + client = yield get_client() + future = client.submit(inc, 10) + result = yield future + raise gen.Return(result) + + results = c.run_coroutine(f) + assert results == {a['address']: 11, + b['address']: 11} + + +@gen_cluster() +def test_global_workers(s, a, b): + from distributed.worker import _global_workers + n = len(_global_workers) + w = _global_workers[-1]() + assert w is a or w is b + yield a._close() + yield b._close() + assert len(_global_workers) == n - 2 diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index cdfbef8b094..e208adf2184 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -2,15 +2,17 @@ import random from time import sleep +import warnings import dask from dask import delayed import pytest from tornado import gen -from distributed import worker_client, Client, as_completed +from distributed import worker_client, Client, as_completed, get_worker from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, double, cluster, loop 
+from distributed.worker import thread_state @gen_cluster(client=True) @@ -36,12 +38,13 @@ def func(x): @gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) def test_scatter_from_worker(c, s, a, b): def func(): + from distributed.worker import thread_state with worker_client() as c: futures = c.scatter([1, 2, 3, 4, 5]) assert isinstance(futures, (list, tuple)) assert len(futures) == 5 - x = dict(c.worker.data) + x = dict(get_worker().data) y = {f.key: i for f, i in zip(futures, [1, 2, 3, 4, 5])} assert x == y @@ -61,7 +64,7 @@ def func(): o = object() futures = c.scatter({'x': o}) - correct &= c.worker.data['x'] is o + correct &= get_worker().data['x'] is o return correct future = c.submit(func) @@ -110,7 +113,7 @@ def func(): def test_same_loop(c, s, a, b): def f(): with worker_client() as lc: - return lc.loop is lc.worker.loop + return lc.loop is get_worker().loop future = c.submit(f) result = yield future @@ -159,11 +162,11 @@ def test_separate_thread_false(c, s, a): a.count = 0 def f(i): with worker_client(separate_thread=False) as client: - client.worker.count += 1 - assert client.worker.count <= 3 + get_worker().count += 1 + assert get_worker().count <= 3 sleep(random.random() / 40) - assert client.worker.count <= 3 - client.worker.count -= 1 + assert get_worker().count <= 3 + get_worker().count -= 1 return i futures = c.map(f, range(20)) @@ -198,3 +201,19 @@ def f(x): b2.compute() assert dask.context._globals['get'] == c.get + + +@gen_cluster(client=True) +def test_local_client_warning(c, s, a, b): + from distributed import local_client + def func(x): + with warnings.catch_warnings(record=True) as record: + with local_client() as c: + x = c.submit(inc, x) + result = x.result() + assert any("worker_client" in str(r.message) for r in record) + return result + + future = c.submit(func, 10) + result = yield future + assert result == 11 diff --git a/distributed/variable.py b/distributed/variable.py index e4310b4bd07..aaa2d53ecf7 100644 --- 
a/distributed/variable.py +++ b/distributed/variable.py @@ -15,6 +15,7 @@ from .client import Future, _get_global_client, Client from .metrics import time from .utils import tokey, log_errors +from .worker import get_client logger = logging.getLogger(__name__) @@ -198,7 +199,9 @@ def __getstate__(self): def __setstate__(self, state): name, address = state - client = _get_global_client() - if client is None or client.scheduler.address != address: + try: + client = get_client(address) + assert client.address == address + except (AttributeError, AssertionError): client = Client(address, set_as_default=False) self.__init__(name=name, client=client) diff --git a/distributed/worker.py b/distributed/worker.py index 0a7dafd3948..cd894cf7f13 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -8,9 +8,10 @@ from pickle import PicklingError import random import tempfile -from threading import current_thread, local +import threading import shutil import sys +import weakref from dask.core import istask from dask.compatibility import apply @@ -26,7 +27,7 @@ from .batched import BatchedSend from .comm import get_address_host, get_local_address_for from .config import config -from .compatibility import unicode +from .compatibility import unicode, get_thread_identity from .core import (error_message, CommClosedError, rpc, pingpong, coerce_to_address) from .metrics import time @@ -36,7 +37,7 @@ serialize_bytelist) from .security import Security from .sizeof import safe_sizeof as sizeof -from .threadpoolexecutor import ThreadPoolExecutor +from .threadpoolexecutor import ThreadPoolExecutor, secede as tpe_secede from .utils import (funcname, get_ip, has_arg, _maybe_complex, log_errors, ignoring, validate_key, mp_context, import_file, silence_logging) @@ -44,7 +45,7 @@ _ncores = mp_context.cpu_count() -thread_state = local() +thread_state = threading.local() logger = logging.getLogger(__name__) @@ -66,6 +67,8 @@ PROCESSING = ('waiting', 'ready', 'constrained', 'executing', 
'long-running') READY = ('ready', 'constrained') +_global_workers = [] + class WorkerBase(ServerNode): def __init__(self, scheduler_ip, scheduler_port=None, ncores=None, @@ -317,6 +320,8 @@ def _close(self, report=True, timeout=10, nanny=True): if self.status in ('closed', 'closing'): return logger.info("Stopping worker at %s", self.address) + if self._client: + yield self._client._close() self.status = 'closing' self.stop() self.heartbeat_callback.stop() @@ -341,6 +346,18 @@ def _close(self, report=True, timeout=10, nanny=True): self.rpc.close() self._closed.set() + self._remove_from_global_workers() + + def __del__(self): + self._remove_from_global_workers() + + def _remove_from_global_workers(self): + with log_errors(): + for ref in list(_global_workers): + if ref() is self: + _global_workers.remove(ref) + if ref() is None: + _global_workers.remove(ref) @gen.coroutine def terminate(self, comm, report=True): @@ -652,6 +669,7 @@ def apply_function(function, args, kwargs, execution_state, key): ------- msg: dictionary with status, result/error, timings, etc.. """ + thread_state.start_time = time() thread_state.execution_state = execution_state thread_state.key = key start = time() @@ -670,7 +688,7 @@ def apply_function(function, args, kwargs, execution_state, key): end = time() msg['start'] = start msg['stop'] = end - msg['thread'] = current_thread().ident + msg['thread'] = get_thread_identity() return msg @@ -944,6 +962,7 @@ def __init__(self, *args, **kwargs): self.has_what = defaultdict(set) self.pending_data_per_worker = defaultdict(deque) self.extensions = {} + self._lock = threading.Lock() self.data_needed = deque() # TODO: replace with heap? 
@@ -1005,9 +1024,12 @@ def __init__(self, *args, **kwargs): self.incoming_count = 0 self.outgoing_transfer_log = deque(maxlen=(100000)) self.outgoing_count = 0 + self._client = None WorkerBase.__init__(self, *args, **kwargs) + _global_workers.append(weakref.ref(self)) + def __str__(self): return "<%s: %s, %s, stored: %d, running: %d/%d, ready: %d, comm: %d, waiting: %d>" % ( self.__class__.__name__, self.address, self.status, @@ -1399,14 +1421,16 @@ def transition_executing_done(self, key, value=no_value): import pdb; pdb.set_trace() raise - def transition_executing_long_running(self, key): + def transition_executing_long_running(self, key, compute_duration=None): try: if self.validate: assert key in self.executing self.executing.remove(key) self.long_running.add(key) - self.batched_stream.send({'op': 'long-running', 'key': key}) + self.batched_stream.send({'op': 'long-running', + 'key': key, + 'compute_duration': compute_duration}) self.ensure_computing() except Exception as e: @@ -1415,6 +1439,10 @@ def transition_executing_long_running(self, key): import pdb; pdb.set_trace() raise + def maybe_transition_long_running(self, key, compute_duration=None): + if self.task_state.get(key) == 'executing': + self.transition(key, 'long-running', compute_duration=compute_duration) + ########################## # Gather Data from Peers # ########################## @@ -2103,6 +2131,40 @@ def story(self, *keys): for c in msg if isinstance(c, (tuple, list, set)))] + @property + def client(self): + """ Get local client attached to this worker + + If no such client exists, create one + + See Also + -------- + get_client + """ + with self._lock: + if self._client: + return self._client + + try: + from .client import default_client + client = default_client() + except ValueError: # no clients found, need to make a new one + pass + else: + if client.scheduler.address == self.scheduler.address: + self._client = client + + if not self._client: + from .client import Client + asynchronous 
= get_thread_identity() == self.loop._thread_ident + self._client = Client(self.scheduler.address, loop=self.loop, + security=self.security, + set_as_default=True, + asynchronous=asynchronous) + if not asynchronous: + assert self._client.status == 'running' + return self._client + def get_worker(): """ Get the worker currently running this task @@ -2119,6 +2181,83 @@ def get_worker(): See Also -------- + get_client + worker_client + """ + try: + return thread_state.execution_state['worker'] + except AttributeError: + for ref in _global_workers[::-1]: + worker = ref() + if worker: + return worker + raise ValueError("No workers found") + + +def get_client(address=None): + """ Get a client while within a task + + This client connects to the same scheduler to which the worker is connected + + Examples + -------- + >>> def f(): + ... client = get_client() + ... futures = client.map(lambda x: x + 1, range(10)) # spawn many tasks + ... results = client.gather(futures) + ... return sum(results) + + >>> future = client.submit(f) # doctest: +SKIP + >>> future.result() # doctest: +SKIP + 55 + + See Also + -------- + get_worker worker_client + secede + """ + try: + worker = get_worker() + except ValueError: # could not find worker + pass + else: + if not address or worker.scheduler.address == address: + return worker.client + + from .client import _get_global_client + client = _get_global_client() # TODO: assumes the same scheduler + if client and not address or client.scheduler.address == address: + return client + elif address: + from .client import Client + return Client(address) + else: + raise ValueError("No global client found and no address provided") + + +def secede(): + """ + Have this task secede from the worker's thread pool + + This opens up a new scheduling slot and a new thread for a new task. + + Examples + -------- + >>> def mytask(x): + ... # do some work + ... client = get_client() + ... futures = client.map(...) # do some remote work + ... 
secede() # while that work happens, remove ourself from the pool + ... return client.gather(futures) # return gathered results + + See Also + -------- + get_client + get_worker """ - return thread_state.execution_state['worker'] + worker = get_worker() + tpe_secede() # have this thread secede from the thread pool + duration = time() - thread_state.start_time + worker.loop.add_callback(worker.maybe_transition_long_running, + thread_state.key, compute_duration=duration) diff --git a/distributed/worker_client.py b/distributed/worker_client.py index a6b9b215462..63ddfba3973 100644 --- a/distributed/worker_client.py +++ b/distributed/worker_client.py @@ -1,17 +1,9 @@ from __future__ import print_function, division, absolute_import from contextlib import contextmanager -from datetime import timedelta -from toolz import keymap, valmap, merge +import warnings -from dask.base import tokenize -from tornado import gen - -from .client import AllExit, Client, Future, pack_data, unpack_remotedata -from dask.compatibility import apply -from .sizeof import sizeof from .threadpoolexecutor import secede -from .utils import All, log_errors, sync, tokey, ignoring from .worker import thread_state, get_worker @@ -45,128 +37,17 @@ def worker_client(timeout=3, separate_thread=True): See Also -------- get_worker + get_client + secede """ - address = thread_state.execution_state['scheduler'] - worker = thread_state.execution_state['worker'] + worker = get_worker() if separate_thread: secede() # have this thread secede from the thread pool worker.loop.add_callback(worker.transition, thread_state.key, 'long-running') - with WorkerClient(address, loop=worker.loop, security=worker.security, - asynchronous=True) as wc: - # Make sure connection errors are bubbled to the caller - sync(wc.loop, gen.with_timeout, timedelta(seconds=timeout), wc._started) - wc.asynchronous = False - assert wc.status == 'running' - yield wc - - -local_client = worker_client - - -class WorkerClient(Client): - """ An 
Client designed to operate from a Worker process - - This client has had a few methods altered to make it more efficient for - working directly from the worker nodes. In particular scatter/gather first - look to the local data dictionary rather than sending data over the network - """ - def __init__(self, *args, **kwargs): - loop = kwargs.get('loop') - self.worker = get_worker() - kwargs['set_as_default'] = False - sync(loop, apply, Client.__init__, (self,) + args, kwargs) - - @gen.coroutine - def _scatter(self, data, workers=None, broadcast=False, direct=None): - """ Scatter data to local data dictionary - - Rather than send data out to the cluster we keep data local. However - we do report to the scheduler that the local worker has the scattered - data. This allows other workers to come by and steal this data if - desired. - - Keywords like ``broadcast=`` do not work, however operations like - ``.replicate`` work fine after calling scatter, which can fill in for - this functionality. - """ - with log_errors(): - if not (workers is None and broadcast is False): - raise NotImplementedError("Scatter from worker doesn't support workers or broadcast keywords") - - if isinstance(data, dict) and not all(isinstance(k, (bytes, str)) - for k in data): - d = yield self._scatter(keymap(tokey, data), workers, broadcast) - raise gen.Return({k: d[tokey(k)] for k in data}) - - if isinstance(data, type(range(0))): - data = list(data) - input_type = type(data) - names = False - unpack = False - if isinstance(data, (set, frozenset)): - data = list(data) - if not isinstance(data, (dict, list, tuple, set, frozenset)): - unpack = True - data = [data] - if isinstance(data, (list, tuple)): - names = list(map(tokenize, data)) - data = dict(zip(names, data)) - - types = valmap(type, data) - assert isinstance(data, dict) - - self.worker.update_data(data=data, report=False) - - yield self.scheduler.update_data( - who_has={key: [self.worker.address] for key in data}, - nbytes=valmap(sizeof, 
data), - client=self.id) - - out = {k: self._Future(k, self) for k in data} - for key, typ in types.items(): - self.futures[key].finish(type=typ) - - if issubclass(input_type, (list, tuple, set, frozenset)): - out = input_type(out[k] for k in names) - - if unpack: - assert len(out) == 1 - out = list(out.values())[0] - raise gen.Return(out) - - @gen.coroutine - def _gather(self, futures, errors='raise', direct=False): - """ - - Exactly like Client._gather, but get data directly from the local - worker data dictionary directly rather than through the scheduler. - - TODO: avoid scheduler for other communications, and assume that we can - communicate directly with the other workers. - """ - futures2, keys = unpack_remotedata(futures, byte_keys=True) - keys = [tokey(k) for k in keys] - - @gen.coroutine - def wait(k): - """ Want to stop the All(...) early if we find an error """ - yield self.futures[k].event.wait() - if self.futures[k].status != 'finished': - raise AllExit() - - with ignoring(AllExit): - yield All([wait(key) for key in keys if key in self.futures]) - - local = {k: self.worker.data[k] for k in keys - if k in self.worker.data} - - futures3 = {k: Future(k, self) for k in keys if k not in local} + yield worker.client - futures4 = pack_data(futures2, merge(local, futures3)) - if not futures3: - raise gen.Return(futures4) - result = yield Client._gather(self, futures4, errors=errors, - direct=True) - raise gen.Return(result) +def local_client(*args, **kwargs): + warnings.warn("local_client has moved to worker_client") + return worker_client(*args, **kwargs) diff --git a/docs/source/api.rst b/docs/source/api.rst index bf74fef7ea5..33623277b38 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -42,6 +42,8 @@ API .. autosummary:: worker_client get_worker + get_client + secede .. currentmodule:: distributed.recreate_exceptions @@ -141,6 +143,8 @@ Other .. autofunction:: distributed.worker_client .. autofunction:: distributed.get_worker +.. 
autofunction:: distributed.get_client +.. autofunction:: distributed.secede .. autoclass:: Queue :members: