diff --git a/distributed/__init__.py b/distributed/__init__.py index a9fc668e106..52cb8abc210 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -12,7 +12,7 @@ from .scheduler import Scheduler from .utils import sync from .variable import Variable -from .worker import Worker, get_worker +from .worker import Worker, get_worker, get_client, secede from .worker_client import local_client, worker_client from ._version import get_versions diff --git a/distributed/channels.py b/distributed/channels.py deleted file mode 100644 index 0ed6d395150..00000000000 --- a/distributed/channels.py +++ /dev/null @@ -1,297 +0,0 @@ -from __future__ import print_function, division, absolute_import - -from collections import deque -import logging -from time import sleep -import threading -import warnings - -from .client import Future -from .core import CommClosedError -from .utils import tokey, log_errors - -logger = logging.getLogger(__name__) - - -class ChannelScheduler(object): - """ A plugin for the scheduler to manage channels - - This adds the following routes to the scheduler - - * channel-subscribe - * channel-unsubsribe - * channel-append - """ - def __init__(self, scheduler): - self.scheduler = scheduler - self.deques = dict() - self.counts = dict() - self.clients = dict() - self.stopped = dict() - - handlers = {'channel-subscribe': self.subscribe, - 'channel-unsubscribe': self.unsubscribe, - 'channel-append': self.append, - 'channel-stop': self.stop} - - self.scheduler.client_handlers.update(handlers) - self.scheduler.extensions['channels'] = self - - def subscribe(self, channel=None, client=None, maxlen=None): - logger.info("Add new client to channel, %s, %s", client, channel) - if channel not in self.deques: - logger.info("Add new channel %s", channel) - self.deques[channel] = deque(maxlen=maxlen) - self.counts[channel] = 0 - self.clients[channel] = set() - self.stopped[channel] = False - self.clients[channel].add(client) - - comm = self.scheduler.comms[client] - for type, value in self.deques[channel]: - comm.send({'op': 'channel-append', - 'type': type, - 'value': value, - 'channel': channel}) - - if self.stopped[channel]: - comm.send({'op': 'channel-stop', - 'channel': channel}) - - def unsubscribe(self, channel=None, client=None): - logger.info("Remove client from channel, %s, %s", client, channel) - self.clients[channel].remove(client) - if not self.clients[channel]: - del self.deques[channel] - del self.counts[channel] - del self.clients[channel] - del self.stopped[channel] - - def append(self, channel=None, type=None, value=None, client=None): - if self.stopped[channel]: - return - - if len(self.deques[channel]) == self.deques[channel].maxlen: - # TODO: future might still be in deque - typ, val = self.deques[channel].popleft() - if typ == 'Future': - self.scheduler.client_releases_keys(keys=[val], - client='streaming-%s' % channel) - - self.deques[channel].append((type, value)) - self.counts[channel] += 1 - self.report(channel, type, value) - - client = 'streaming-%s' % channel - if type == 'Future': - self.scheduler.client_desires_keys(keys=[value], client=client) - - def stop(self, channel=None, client=None): - self.stopped[channel] = True - logger.info("Stop channel %s", channel) - for client in list(self.clients[channel]): - try: - comm = self.scheduler.comms[client] - comm.send({'op': 'channel-stop', - 'channel': channel}) - except (KeyError, CommClosedError): - self.unsubscribe(channel, client) - - def report(self, channel, type, value): - for client in list(self.clients[channel]): - try: - comm = self.scheduler.comms[client] - comm.send({'op': 'channel-append', - 'type': type, - 'value': value, - 'channel': channel}) - except (KeyError, CommClosedError): - self.unsubscribe(channel, client) - - -class ChannelClient(object): - def __init__(self, client): - self.client = client - self.channels = dict() - self.client.extensions['channels'] = self - - handlers = {'channel-append': self.receive_key, - 'channel-stop': self.receive_stop} - - self.client._handlers.update(handlers) - - self.client.channel = self._create_channel # monkey patch - - def _create_channel(self, channel, maxlen=None): - if channel not in self.channels: - c = Channel(self.client, channel, maxlen=maxlen) - self.channels[channel] = c - return c - else: - return self.channels[channel] - - def receive_key(self, channel=None, type=None, value=None): - if channel not in self.channels: - self._create_channel(channel) - self.channels[channel]._receive_update(type, value) - - def receive_stop(self, channel=None): - self.channels[channel]._receive_stop() - - -class Channel(object): - """ - A changing stream of futures or data shared between clients - - Several clients connected to the same scheduler can communicate a sequence - of data or futures between each other through shared *channels*. All - clients can append to the channel at any time. All clients will be updated - when a channel updates. The central scheduler maintains consistency and - ordering of events. - - Channels can contain Future objects or generic data. All data should be - small and msgpack encodable (strings, numbers, lists, dicts.) Channels - should not be used to send large datasets directly. Instead scatter the - data into a Future and send that instead. - - Examples - -------- - - Create channels from your Client: - - >>> client = Client('scheduler-address:8786') # doctest: +SKIP - >>> chan = client.channel('my-channel') # doctest: +SKIP - - Append futures onto a channel - - >>> future = client.submit(add, 1, 2) # doctest: +SKIP - >>> chan.append(future) # doctest: +SKIP - - A channel maintains a collection of current futures added by both your - client, and others. - - >>> chan.data # doctest: +SKIP - deque([, - ]) - - You can iterate over a channel to get back futures. - - >>> for future in chan: # doctest: +SKIP - ... pass - - You can send small and simple data as well - - >>> chan.append({'score': 123}) # doctest: +SKIP - - To publish large amounts of data, scatter the data into a future - - >>> [future] = client.scatter([large_numpy_array]) # doctest: +SKIP - >>> chan.append(future) # doctest: +SKIP - """ - def __init__(self, client, name, maxlen=None): - self.client = client - self.name = name - self.data = deque(maxlen=maxlen) - self.stopped = False - self.count = 0 - self._pending = dict() - self._lock = threading.Lock() - self._thread_condition = threading.Condition() - - self.client._send_to_scheduler({'op': 'channel-subscribe', - 'channel': name, - 'maxlen': maxlen, - 'client': self.client.id}) - - @property - def futures(self): - warnings.warn("The .futures attribute has moved to .data") - return self.data - - def append(self, value): - """ Append a future onto the channel """ - if self.stopped: - raise StopIteration() - if isinstance(value, Future): - msg = {'op': 'channel-append', - 'channel': self.name, - 'type': 'Future', - 'value': tokey(value.key)} - self._pending[value.key] = value # hold on to reference until ack - else: - msg = {'op': 'channel-append', - 'channel': self.name, - 'type': 'value', - 'value': value} - - self.client._send_to_scheduler(msg) - - def stop(self): - if self.stopped: - return - self.client._send_to_scheduler({'op': 'channel-stop', - 'channel': self.name}) - - def _receive_update(self, type=None, value=None): - with self._lock: - self.count += 1 - if type == 'Future': - self.data.append(Future(value, self.client, inform=True)) - else: - self.data.append(value) - if type == 'Future': - if value in self._pending: - del self._pending[value] - - with self._thread_condition: - self._thread_condition.notify_all() - - def _receive_stop(self): - logger.info("Channel stopped: %s", self.name) - self.stopped = True - with self._thread_condition: - self._thread_condition.notify_all() - - def flush(self): - """ - Wait for acknowledgement from the scheduler on any pending futures - """ - while self._pending: - sleep(0.01) - - def __del__(self): - if not self.client.scheduler_comm.comm: - self.client._send_to_scheduler({'op': 'channel-unsubscribe', - 'channel': self.name, - 'client': self.client.id}) - - def __iter__(self): - with log_errors(): - with self._lock: - last = self.count - L = list(self.data) - for future in L: - yield future - - while True: - while self.count == last: - if self.stopped: - return - self._thread_condition.acquire() - self._thread_condition.wait() - self._thread_condition.release() - - with self._lock: - n = min(self.count - last, len(self.data)) - L = [self.data[i] for i in range(-n, 0)] - last = self.count - for f in L: - yield f - - def __len__(self): - return len(self.data) - - def __str__(self): - return "" % (self.name, len(self.data)) - - __repr__ = __str__ diff --git a/distributed/client.py b/distributed/client.py index dde5c741893..ee2268a1611 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -45,7 +45,8 @@ from .protocol import to_serialize from .protocol.pickle import dumps, loads from .security import Security -from .worker import dumps_task +from .sizeof import sizeof +from .worker import dumps_task, thread_state, get_client from .utils import (All, sync, funcname, ignoring, queue_to_iterator, tokey, log_errors, str_graph, key_split, format_bytes) from .versions import get_versions @@ -269,10 +270,11 @@ def release(self): self.client._dec_ref(tokey(self.key)) def __getstate__(self): - return self.key + return (self.key, self.client.scheduler.address) - def __setstate__(self, key): - c = default_client() + def __setstate__(self, state): + key, address = state + c = get_client(address) Future.__init__(self, key, c) c._send_to_scheduler({'op': 'update-graph', 'tasks': {}, 'keys': [tokey(self.key)], 'client': c.id}) @@ -430,7 +432,7 @@ def __init__(self, address=None, loop=None, timeout=5, self.futures = dict() self.refcount = defaultdict(lambda: 0) self.coroutines = [] - self.id = type(self).__name__ + '-' + str(uuid.uuid1()) + self.id = type(self).__name__ + '-' + str(uuid.uuid1(clock_seq=os.getpid())) self.generation = 0 self.status = None self._pending_msg_buffer = [] @@ -441,7 +443,7 @@ def __init__(self, address=None, loop=None, timeout=5, assert isinstance(self.security, Security) self.connection_args = self.security.get_connection_args('client') self._connecting_to_scheduler = False - self.asynchronous = asynchronous + self._asynchronous = asynchronous self._loop_thread = None self.scheduler = None self._lock = threading.Lock() @@ -486,13 +488,34 @@ def __init__(self, address=None, loop=None, timeout=5, super(Client, self).__init__(connection_args=self.connection_args, io_loop=self.loop) - self.start(timeout=timeout, asynchronous=asynchronous) + self.start(timeout=timeout) - from distributed.channels import ChannelClient - ChannelClient(self) # registers itself on construction from distributed.recreate_exceptions import ReplayExceptionClient ReplayExceptionClient(self) + @property + def asynchronous(self): + """ Are we running in the event loop? + + This is true if the user signaled that we might be when creating the + client as in the following:: + + client = Client(asynchronous=True) + + However, we override this expectation if we can definitively tell that + we are running from a thread that is not the event loop. This is + common when calling get_client() from within a worker task. Even + though the client was originally created in asynchronous mode we may + find ourselves in contexts when it is better to operate synchronously. + """ + result = self._asynchronous + try: + if get_thread_identity() != self.loop._thread_ident: + result = False + except AttributeError: # AsyncIOLoop doesn't have _thread_ident + pass + return result + def sync(self, func, *args, **kwargs): if self.asynchronous: callback_timeout = kwargs.pop('callback_timeout', None) @@ -565,11 +588,12 @@ def _repr_html_(self): else: return text - def start(self, asynchronous=None, **kwargs): + def start(self, **kwargs): """ Start scheduler running in separate thread """ if self._loop_thread is not None: return - if not asynchronous and not self.loop._running: + + if not self.asynchronous and not self.loop._running: self._loop_thread = threading.Thread(target=self.loop.start, name="Client loop") self._loop_thread.daemon = True @@ -578,14 +602,16 @@ def start(self, asynchronous=None, **kwargs): self._should_close_loop = True while not self.loop._running: sleep(0.001) + pc = PeriodicCallback(lambda: None, 1000, io_loop=self.loop) self.loop.add_callback(pc.start) _set_global_client(self) - if asynchronous: + self.status = 'connecting' + + if self.asynchronous: self._started = self._start(**kwargs) else: sync(self.loop, self._start, **kwargs) - self.status = 'running' def __await__(self): return self._started.__await__() @@ -837,6 +863,23 @@ def _handle_error(self, exception=None): logger.warning("Scheduler exception:") logger.exception(exception) + @gen.coroutine + def _close(self): + with log_errors(): + if self.status == 'closed': + raise gen.Return() + if self.status == 'running': + self._send_to_scheduler({'op': 'close-stream'}) + with ignoring(AttributeError): + yield self.scheduler_comm.close() + with ignoring(AttributeError): + self.scheduler.close_rpc() + self.status = 'closed' + + def close(self): + """ Close this client and its connection to the scheduler """ + return self.sync(self._close) + @gen.coroutine def _shutdown(self, fast=False): """ Send shutdown signal and wait until scheduler completes """ @@ -920,6 +963,12 @@ def get_executor(self, **kwargs): """ return ClientExecutor(self, **kwargs) + def channel(self, *args, **kwargs): + """ Deprecated: see dask.distributed.Queue instead """ + msg = ("Channels have been removed. Consider using Queues instead. " + "http://distributed.readthedocs.io/en/latest/api.html#distributed.Queue") + raise NotImplementedError(msg) + def submit(self, func, *args, **kwargs): """ Submit a function application to the scheduler @@ -1084,7 +1133,7 @@ def map(self, func, *iterables, **kwargs): keys = [key + '-' + tokenize(func, kwargs, *args) for args in zip(*iterables)] else: - uid = str(uuid.uuid4()) + uid = str(uuid.uuid1()) keys = [key + '-' + uid + '-' + str(i) for i in range(min(map(len, iterables)))] if iterables else [] @@ -1130,7 +1179,7 @@ def map(self, func, *iterables, **kwargs): return [futures[tokey(k)] for k in keys] @gen.coroutine - def _gather(self, futures, errors='raise', direct=False): + def _gather(self, futures, errors='raise', direct=False, local_worker=None): futures2, keys = unpack_remotedata(futures, byte_keys=True) keys = [tokey(key) for key in keys] bad_data = dict() @@ -1175,20 +1224,29 @@ def wait(k): bad_data[key] = None else: raise ValueError("Bad value, `errors=%s`" % errors) + keys = [k for k in keys if k not in bad_keys] - if direct: + data = {} + + if local_worker: # look inside local worker + data.update({k: local_worker.data[k] + for k in keys + if k in local_worker.data}) + keys = [k for k in keys if k not in data] + + if direct or local_worker: # gather directly from workers who_has = yield self.scheduler.who_has(keys=keys) - data, missing_keys, missing_workers = yield gather_from_workers( + data2, missing_keys, missing_workers = yield gather_from_workers( who_has, rpc=self.rpc, close=False) - response = {'status': 'OK', 'data': data} + response = {'status': 'OK', 'data': data2} if missing_keys: - keys2 = [key for key in keys if key not in data] + keys2 = [key for key in keys if key not in data2] response = yield self.scheduler.gather(keys=keys2) if response['status'] == 'OK': - response['data'].update(data) + response['data'].update(data2) - else: + else: # ask scheduler to gather data for us response = yield self.scheduler.gather(keys=keys) if response['status'] == 'error': @@ -1204,7 +1262,7 @@ def wait(k): if bad_data and errors == 'skip' and isinstance(futures2, list): futures2 = [f for f in futures2 if f not in bad_data] - data = response['data'] + data.update(response['data']) result = pack_data(futures2, merge(data, bad_data)) raise gen.Return(result) @@ -1275,11 +1333,16 @@ def gather(self, futures, errors='raise', maxsize=0, direct=False): return (self.gather(f, errors=errors, direct=direct) for f in futures) else: + if hasattr(thread_state, 'execution_state'): # within worker task + local_worker = thread_state.execution_state['worker'] + else: + local_worker = None return self.sync(self._gather, futures, errors=errors, - direct=direct) + direct=direct, local_worker=local_worker) @gen.coroutine - def _scatter(self, data, workers=None, broadcast=False, direct=False): + def _scatter(self, data, workers=None, broadcast=False, direct=False, + local_worker=None): if isinstance(workers, six.string_types): workers = [workers] if isinstance(data, dict) and not all(isinstance(k, (bytes, unicode)) @@ -1305,24 +1368,34 @@ def _scatter(self, data, workers=None, broadcast=False, direct=False): assert isinstance(data, dict) - data2 = valmap(to_serialize, data) types = valmap(type, data) - if direct: - ncores = yield self.scheduler.ncores(workers=workers) - if not ncores: - raise ValueError("No valid workers") - _, who_has, nbytes = yield scatter_to_workers(ncores, data2, - report=False, - rpc=self.rpc) + if local_worker: # running within task + local_worker.update_data(data=data, report=False) + + yield self.scheduler.update_data( + who_has={key: [local_worker.address] for key in data}, + nbytes=valmap(sizeof, data), + client=self.id) - yield self.scheduler.update_data(who_has=who_has, nbytes=nbytes) else: - yield self.scheduler.scatter(data=data2, workers=workers, - client=self.id, - broadcast=broadcast) + data2 = valmap(to_serialize, data) + if direct: + ncores = yield self.scheduler.ncores(workers=workers) + if not ncores: + raise ValueError("No valid workers") - out = {k: self._Future(k, self) for k in data2} + _, who_has, nbytes = yield scatter_to_workers(ncores, data2, + report=False, + rpc=self.rpc) + + yield self.scheduler.update_data(who_has=who_has, nbytes=nbytes) + else: + yield self.scheduler.scatter(data=data2, workers=workers, + client=self.id, + broadcast=broadcast) + + out = {k: self._Future(k, self) for k in data} for key, typ in types.items(): self.futures[key].finish(type=typ) @@ -1438,8 +1511,13 @@ def scatter(self, data, workers=None, broadcast=False, direct=False, maxsize=0): else: return queue_to_iterator(qout) else: + if hasattr(thread_state, 'execution_state'): # inside worker task + local_worker = thread_state.execution_state['worker'] + else: + local_worker = None return self.sync(self._scatter, data, workers=workers, - broadcast=broadcast, direct=direct) + broadcast=broadcast, direct=direct, + local_worker=local_worker) @gen.coroutine def _cancel(self, futures): @@ -1743,8 +1821,8 @@ def _graph_to_futures(self, dsk, keys, restrictions=None, 'restrictions': restrictions or {}, 'loose_restrictions': loose_restrictions, 'priority': priority, - 'resources': resources}) - + 'resources': resources, + 'submitting_task': getattr(thread_state, 'key', None)}) return futures def get(self, dsk, keys, restrictions=None, loose_restrictions=None, diff --git a/distributed/queues.py b/distributed/queues.py index d633181bda0..bb2b617571d 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -14,6 +14,7 @@ from .client import Future, _get_global_client, Client from .utils import tokey, sync +from .worker import get_client logger = logging.getLogger(__name__) @@ -240,7 +241,9 @@ def __getstate__(self): def __setstate__(self, state): name, address = state - client = _get_global_client() - if client is None or client.scheduler.address != address: + try: + client = get_client(address) + assert client.address == address + except (AttributeError, AssertionError): client = Client(address, set_as_default=False) self.__init__(name=name, client=client) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 79b8d54a666..1a10358a29f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -38,7 +38,6 @@ from .utils_comm import (scatter_to_workers, gather_from_workers) from .versions import get_versions -from .channels import ChannelScheduler from .publish import PublishExtension from .queues import QueueExtension from .recreate_exceptions import ReplayExceptionScheduler @@ -191,7 +190,7 @@ class Scheduler(ServerNode): def __init__(self, center=None, loop=None, delete_interval=500, synchronize_worker_interval=60000, services=None, allowed_failures=ALLOWED_FAILURES, - extensions=[ChannelScheduler, PublishExtension, WorkStealing, + extensions=[PublishExtension, WorkStealing, ReplayExceptionScheduler, QueueExtension, VariableExtension], validate=False, scheduler_file=None, security=None, @@ -297,7 +296,8 @@ def __init__(self, center=None, loop=None, 'task-erred': self.handle_task_erred, 'release': self.handle_missing_data, 'release-worker-data': self.release_worker_data, - 'add-keys': self.add_keys} + 'add-keys': self.add_keys, + 'long-running': self.handle_long_running} self.client_handlers = {'update-graph': self.update_graph, 'client-desires-keys': self.client_desires_keys, @@ -643,7 +643,8 @@ def add_worker(self, comm=None, address=None, keys=(), ncores=None, def update_graph(self, client=None, tasks=None, keys=None, dependencies=None, restrictions=None, priority=None, - loose_restrictions=None, resources=None): + loose_restrictions=None, resources=None, + submitting_task=None): """ Add new computations to the internal dask graph @@ -700,10 +701,17 @@ def update_graph(self, client=None, tasks=None, keys=None, recommendations = OrderedDict() new_priority = priority or order(tasks) # TODO: define order wrt old graph - self.generation += 1 # older graph generations take precedence + if submitting_task: # sub-tasks get better priority than parent tasks + try: + generation = self.priority[submitting_task][0] - 0.01 + except KeyError: # super-task already cleaned up + generation = self.generation + else: + self.generation += 1 # older graph generations take precedence + generation = self.generation for key in set(new_priority) & touched: if key not in self.priority: - self.priority[key] = (self.generation, new_priority[key]) # prefer old + self.priority[key] = (generation, new_priority[key]) # prefer old if restrictions: # *restrictions* is a dict keying task ids to lists of @@ -1294,6 +1302,37 @@ def release_worker_data(self, stream=None, keys=None, worker=None): if recommendations: self.transitions(recommendations) + def handle_long_running(self, key=None, worker=None, compute_duration=None): + """ A task has seceded from the thread pool + + We stop the task from being stolen in the future, and change task + duration accounting as if the task has stopped. + """ + self.extensions['stealing'].remove_key_from_stealable(key) + + try: + actual_worker = self.rprocessing[key] + except KeyError: + logger.debug("Received long-running signal from duplicate task. " + "Ignoring.") + return + + if compute_duration: + ks = key_split(key) + old_duration = self.task_duration.get(ks, 0) + new_duration = compute_duration + if not old_duration: + avg_duration = new_duration + else: + avg_duration = (0.5 * old_duration + + 0.5 * new_duration) + + self.task_duration[ks] = avg_duration + + worker = self.rprocessing[key] + self.occupancy[actual_worker] -= self.processing[actual_worker][key] + self.processing[actual_worker][key] = 0 + @gen.coroutine def handle_worker(self, worker): """ @@ -2308,7 +2347,7 @@ def transition_processing_memory(self, key, nbytes=None, type=None, ############################# # Update Timing Information # ############################# - if compute_start: + if compute_start and self.processing[worker].get(key, True): # Update average task duration for worker info = self.worker_info[worker] ks = key_split(key) diff --git a/distributed/stealing.py b/distributed/stealing.py index 79f371d2aad..b97baf9d766 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -51,8 +51,6 @@ def __init__(self, scheduler): self.scheduler.events['stealing'] = deque(maxlen=100000) self.count = 0 - scheduler.worker_handlers['long-running'] = self.transition_long_running - @property def log(self): return self.scheduler.events['stealing'] @@ -80,9 +78,6 @@ def transition(self, key, start, finish, compute_start=None, if self.scheduler.task_state[k] == 'processing': self.put_key_in_stealable(k, split=ks) - def transition_long_running(self, key=None, worker=None): - self.remove_key_from_stealable(key) - def put_key_in_stealable(self, key, split=None): worker = self.scheduler.rprocessing[key] cost_multiplier, level = self.steal_time_ratio(key, split=split) diff --git a/distributed/tests/py3_test_asyncio.py b/distributed/tests/py3_test_asyncio.py index a99e7748ac7..438c6ee3153 100644 --- a/distributed/tests/py3_test_asyncio.py +++ b/distributed/tests/py3_test_asyncio.py @@ -179,46 +179,6 @@ async def test_asyncio_exceptions(): assert result == 10 / 2 -@coro_test -async def test_asyncio_channels(): - async with AioClient(processes=False) as c: - x = c.channel('x') - y = c.channel('y') - - assert len(x) == 0 - - while set(c.extensions['channels'].channels) != {'x', 'y'}: - await asyncio.sleep(0.01) - - xx = c.channel('x') - yy = c.channel('y') - - assert len(x) == 0 - - await asyncio.sleep(0.1) - assert set(c.extensions['channels'].channels) == {'x', 'y'} - - future = c.submit(inc, 1) - - x.append(future) - - while not x.data: - await asyncio.sleep(0.01) - - assert len(x) == 1 - - assert xx.data[0].key == future.key - - xxx = c.channel('x') - while not xxx.data: - await asyncio.sleep(0.01) - - assert xxx.data[0].key == future.key - - assert 'x' in repr(x) - assert '1' in repr(x) - - @coro_test async def test_asyncio_exception_on_exception(): async with AioClient(processes=False) as c: diff --git a/distributed/tests/test_channels.py b/distributed/tests/test_channels.py deleted file mode 100644 index 20bf8d1427c..00000000000 --- a/distributed/tests/test_channels.py +++ /dev/null @@ -1,231 +0,0 @@ -from __future__ import print_function, division, absolute_import - -from operator import add -from time import sleep - -import pytest -from toolz import take -from tornado import gen - -from distributed import Client -from distributed import worker_client -from distributed.metrics import time -from distributed.utils_test import gen_cluster, inc, loop, cluster, slowinc - - -@gen_cluster(client=True) -def test_channel(c, s, a, b): - x = c.channel('x') - y = c.channel('y') - - assert len(x) == 0 - - while set(c.extensions['channels'].channels) != {'x', 'y'}: - yield gen.sleep(0.01) - - xx = c.channel('x') - yy = c.channel('y') - - assert len(x) == 0 - - yield gen.sleep(0.1) - assert set(c.extensions['channels'].channels) == {'x', 'y'} - - future = c.submit(inc, 1) - - x.append(future) - - while not x.data: - yield gen.sleep(0.01) - - assert len(x) == 1 - - assert xx.data[0].key == future.key - - xxx = c.channel('x') - while not xxx.data: - yield gen.sleep(0.01) - - assert xxx.data[0].key == future.key - - assert 'x' in repr(x) - assert '1' in repr(x) - - -def test_worker_client(loop): - def produce(n): - with worker_client() as c: - x = c.channel('x') - for i in range(n): - future = c.submit(slowinc, i, delay=0.01, key='f-%d' % i) - x.append(future) - - x.flush() - - def consume(): - with worker_client() as c: - x = c.channel('x') - y = c.channel('y') - last = 0 - for i, future in enumerate(x): - last = c.submit(add, future, last, key='add-' + future.key) - y.append(last) - - y.flush() - - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: - x = c.channel('x') - y = c.channel('y') - - producers = (c.submit(produce, 5), c.submit(produce, 10)) - consumer = c.submit(consume) - - results = [] - for i, future in enumerate(take(15, y)): - result = future.result() - results.append(result) - - assert len(results) == 15 - assert all(0 < r < 100 for r in results) - - -@gen_cluster(client=True) -def test_channel_scheduler(c, s, a, b): - chan = c.channel('chan', maxlen=5) - - x = c.submit(inc, 1) - key = x.key - chan.append(x) - del x - - while not len(chan): - yield gen.sleep(0.01) - - assert 'streaming-chan' in s.who_wants[key] - assert s.wants_what['streaming-chan'] == {key} - - while len(s.who_wants[key]) < 2: - yield gen.sleep(0.01) - - assert s.wants_what[c.id] == {key} - - for i in range(10): - chan.append(c.submit(inc, i)) - - start = time() - while True: - if len(chan) == len(s.task_state) == 5: - break - else: - assert time() < start + 2 - yield gen.sleep(0.01) - - results = yield c.gather(list(chan.data)) - assert results == [6, 7, 8, 9, 10] - - -@gen_cluster(client=True) -def test_multiple_maxlen(c, s, a, b): - c2 = yield Client((s.ip, s.port), asynchronous=True) - - x = c.channel('x', maxlen=10) - assert x.data.maxlen == 10 - x2 = c2.channel('x', maxlen=20) - assert x2.data.maxlen == 20 - - for i in range(10): - x.append(c.submit(inc, i)) - - while len(s.wants_what[c2.id]) < 10: - yield gen.sleep(0.01) - - for i in range(10, 20): - x.append(c.submit(inc, i)) - - while len(x2) < 20: - yield gen.sleep(0.01) - - yield gen.sleep(0.1) - - assert len(x2) == 20 # They stay this long after a delay - assert len(s.task_state) == 20 - - yield c2.shutdown() - - -def test_stop(loop): - def produce(n): - with worker_client() as c: - x = c.channel('x') - for i in range(n): - future = c.submit(slowinc, i, delay=0.01, key='f-%d' % i) - x.append(future) - - x.stop() - x.flush() - - with cluster() as (s, [a, b]): - with Client(s['address']) as c: - x = c.channel('x') - - producer = c.submit(produce, 5) - - futures = list(x) - assert len(futures) == 5 - - with pytest.raises(StopIteration): - x.append(c.submit(inc, 1)) - - with Client(s['address']) as c2: - xx = c2.channel('x') - futures = list(xx) - assert len(futures) == 5 - - -@gen_cluster(client=True) -def test_values(c, s, a, b): - c2 = yield Client((s.ip, s.port), asynchronous=True) - - x = c.channel('x') - x2 = c2.channel('x') - - data = [123, 'Hello!', {'x': [1, 2, 3]}] - for item in data: - x.append(item) - - while len(x2.data) < 3: - yield gen.sleep(0.01) - - assert list(x2.data) == data - - yield c2.shutdown() - - -def test_channel_gets_updates_immediately(loop): - with cluster() as (s, [a, b]): - with Client(s['address']) as c: - x = c.channel('x') - future = c.submit(inc, 1) - x.append(future) - x.flush() - - with Client(s['address']) as c: - x = c.channel('x') - future = next(iter(x)) - assert future.result() == 2 - - -def test_channel_gets_updates_immediately_2(loop): - with cluster() as (s, [a, b]): - with Client(s['address']) as c: - x = c.channel('x') - - with Client(s['address']) as c2: - x2 = c.channel('x') - future = c2.submit(inc, 1) - x2.append(future) - x2.flush() - - future = next(iter(x)) - assert future.result() == 2 diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c576ba8563d..461f7dc7dbe 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -11,6 +11,7 @@ import pickle import random import sys +import threading from threading import Thread, Semaphore from time import sleep import traceback @@ -27,7 +28,8 @@ import dask from dask import delayed from dask.context import _globals -from distributed import Worker, Nanny, recreate_exceptions, fire_and_forget +from distributed import (Worker, Nanny, recreate_exceptions, fire_and_forget, + get_client, secede, get_worker) from distributed.comm import CommClosedError from distributed.utils_comm import WrappedKey from distributed.client import (Client, Future, _wait, @@ -4226,6 +4228,20 @@ def test_quiet_client_shutdown(loop): assert not out +@gen_cluster() +def test_close(s, a, b): + c = yield Client(s.address, asynchronous=True) + future = c.submit(inc, 1) + yield wait(future) + assert c.id in s.wants_what + yield c.close() + + start = time() + while c.id in s.wants_what or s.task_state: + yield gen.sleep(0.01) + assert time() < start + 5 + + def test_threadsafe(loop): with cluster() as (s, [a, b]): with Client(s['address'], loop=loop) as c: @@ -4239,9 +4255,9 @@ def f(_): total = c.submit(sum, list(d)) return total.result() - from multiprocessing.pool import ThreadPool - pool = ThreadPool(20) - results = pool.map(f, range(20)) + from concurrent.futures import ThreadPoolExecutor + e = ThreadPoolExecutor(20) + results = list(e.map(f, range(20))) assert results and all(results) @@ -4258,9 +4274,9 @@ def f(_): sleep(0.001) return total - from multiprocessing.pool import ThreadPool - pool = ThreadPool(30) - results = pool.map(f, range(30)) + from concurrent.futures import ThreadPoolExecutor + e = ThreadPoolExecutor(30) + results = list(e.map(f, range(30))) assert results and all(results) @@ -4278,9 +4294,9 @@ def f(_): sleep(0.001) return total - from multiprocessing.pool import ThreadPool - pool = ThreadPool(30) - results = pool.map(f, range(30)) + from concurrent.futures import ThreadPoolExecutor + e = ThreadPoolExecutor(30) + results = list(e.map(f, range(30))) assert results and all(results) @@ -4292,6 +4308,159 @@ def test_identity(c, s, a, b): assert s.id.lower().startswith('scheduler') +@gen_cluster(client=True, ncores=[('127.0.0.1', 4)] * 2) +def test_get_client(c, s, a, b): + assert get_client() is c + assert c.asynchronous + def f(x): + client = get_client() + future = client.submit(inc, x) + import distributed + assert not client.asynchronous + assert client is distributed.tmp_client + return future.result() + + import distributed + distributed.tmp_client = c + try: + futures = c.map(f, range(5)) + results = yield c.gather(futures) + assert results == list(map(inc, range(5))) + finally: + del distributed.tmp_client + + +@gen_cluster(client=True) +def test_serialize_collections(c, s, a, b): + da = pytest.importorskip('dask.array') + x = da.arange(10, chunks=(5,)).persist() + + def f(x): + assert isinstance(x, da.Array) + return x.sum().compute() + + future = c.submit(f, x) + result = yield future + assert result == sum(range(10)) + + +@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 1, timeout=100) +def test_secede_simple(c, s, a): + def f(): + client = get_client() + secede() + return client.submit(inc, 1).result() + + result = yield c.submit(f) + assert result == 2 + + +@slow +@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2, timeout=60) +def test_secede_balances(c, s, a, b): + def f(x): + client = get_client() + sleep(0.01) # do some work + secede() + futures = client.map(slowinc, range(10), pure=False, delay=0.01) + total = client.submit(sum, futures).result() + return total + + futures = c.map(f, range(100)) + start = time() + while not all(f.status == 'finished' for f in futures): + yield gen.sleep(0.01) + assert threading.active_count() < 50 + + # assert 0.005 < s.task_duration['f'] < 0.1 + assert len(a.log) < 2 * len(b.log) + assert len(b.log) < 2 * len(a.log) + + results = yield c.gather(futures) + assert results == [sum(map(inc, range(10)))] * 100 + + +@gen_cluster(client=True) +def test_sub_submit_priority(c, s, a, b): + def f(): + client = get_client() + client.submit(slowinc, 1, delay=0.2) + + future = c.submit(f) + yield gen.sleep(0.1) + if len(s.task_state) == 2: + f_key = [k for k in s.task_state if k.startswith('f')][0] + slowinc_key = [k for k in s.task_state if k.startswith('slowinc')][0] + assert s.priorities[f_key] > s.priorities[slowinc_key] # lower values schedule first + + +def test_get_client_sync(loop): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + results = c.run(lambda: get_worker().scheduler.address) + assert results == {w['address']: s['address'] for w in [a, b]} + + results = c.run(lambda: get_client().scheduler.address) + assert results == {w['address']: s['address'] for w in [a, b]} + + +@gen_cluster(client=True) +def test_serialize_collections_of_futures(c, s, a, b): + pd = pytest.importorskip('pandas') + dd = pytest.importorskip('dask.dataframe') + from dask.dataframe.utils import assert_eq + + df = pd.DataFrame({'x': [1, 2, 3]}) + ddf = dd.from_pandas(df, npartitions=2).persist() + future = yield c.scatter(ddf) + + ddf2 = yield future + df2 = yield c.compute(ddf2) + + assert_eq(df, df2) + + +def test_serialize_collections_of_futures_sync(loop): + pd = pytest.importorskip('pandas') + dd = pytest.importorskip('dask.dataframe') + from dask.dataframe.utils import assert_eq + + df = pd.DataFrame({'x': [1, 2, 3]}) + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + ddf = dd.from_pandas(df, npartitions=2).persist() + future = c.scatter(ddf) + + result = future.result() + assert_eq(result.compute(), df) + + assert future.type == dd.DataFrame + assert c.submit(lambda x, y: assert_eq(x.compute(), y), future, df).result() + + +def _dynamic_workload(x, delay=0.01): + if delay == 'random': + sleep(random.random() / 2) + else: + sleep(delay) + if x > 4: + return 4 + secede() + client = get_client() + futures = client.map(_dynamic_workload, [x + i + 1 for i in range(2)], + pure=False, delay=delay) + total = client.submit(sum, futures) + return total.result() + + +@pytest.mark.parametrize('delay', [0.02, slow('random')]) +def test_dynamic_workloads_sync(loop, delay): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + future = c.submit(_dynamic_workload, 0, delay=delay) + assert future.result(timeout=40) == 52 + + @gen_cluster(client=True) def test_bytes_keys(c, s, a, b): key = b'inc-123' diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index aaf8decb2f8..71d0fc0894c 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -163,20 +163,20 @@ def test_dataframe_groupby_tasks(loop): with dask.set_options(get=c.get): for ind in [lambda x: 'A', lambda x: x.A]: a = df.groupby(ind(df)).apply(len) - b = ddf.groupby(ind(ddf)).apply(len) + b = ddf.groupby(ind(ddf)).apply(len, meta=int) assert_equal(a, b.compute(get=dask.get).sort_index()) assert not any('partd' in k[0] for k in b.dask) a = df.groupby(ind(df)).B.apply(len) - b = ddf.groupby(ind(ddf)).B.apply(len) + b = ddf.groupby(ind(ddf)).B.apply(len, meta=('B', int)) assert_equal(a, b.compute(get=dask.get).sort_index()) assert not any('partd' in k[0] for k in b.dask) with pytest.raises(NotImplementedError): - ddf.groupby(ddf[['A', 'B']]).apply(len) + ddf.groupby(ddf[['A', 'B']]).apply(len, meta=int) a = df.groupby(['A', 'B']).apply(len) - b = ddf.groupby(['A', 'B']).apply(len) + b = ddf.groupby(['A', 'B']).apply(len, meta=int) assert_equal(a, b.compute(get=dask.get).sort_index()) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index a191cfb07a4..ef6f5c7a5f9 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -548,7 +548,7 @@ def long(delay): assert time() < start + 1 na = len(a.executing) - nb= len(b.executing) + nb = len(b.executing) incs = c.map(inc, range(100), workers=a.address, allow_other_workers=True) diff --git a/distributed/tests/test_threadpoolexecutor.py b/distributed/tests/test_threadpoolexecutor.py index fe8a0a8e2c3..3b46724d34f 100644 --- a/distributed/tests/test_threadpoolexecutor.py +++ b/distributed/tests/test_threadpoolexecutor.py @@ -17,8 +17,6 @@ def f(): assert e.submit(f).result() == 1 - assert len(e._threads) == 1 - list(e.map(sleep, [0.01] * 4)) assert len(threads | e._threads) == 3 diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index 7ea057a6337..fe0d017f383 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -12,7 +12,7 @@ from toolz import take from tornado import gen -from distributed import Nanny, worker_client +from distributed import Nanny, worker_client, Queue from distributed.client import _wait from distributed.metrics import time from distributed.utils_test import (gen_cluster, tls_only_security, @@ -26,22 +26,21 @@ def gen_tls_cluster(**kwargs): @gen_tls_cluster(client=True) -def test_channel(c, s, a, b): +def test_Queue(c, s, a, b): assert s.address.startswith('tls://') - x = c.channel('x') - y = c.channel('y') + x = Queue('x') + y = Queue('y') - assert len(x) == 0 + size = yield x.qsize() + assert size == 0 future = c.submit(inc, 1) - x.append(future) + yield x.put(future) - while not x.data: - yield gen.sleep(0.01) - - assert len(x) == 1 + future2 = yield x.get() + assert future.key == future2.key @gen_tls_cluster(client=True, timeout=None) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index d16ff229757..3403a6f13a8 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -18,7 +18,7 @@ from tornado.ioloop import TimeoutError import distributed -from distributed import Nanny +from distributed import Nanny, Client, get_client, wait, default_client from distributed.core import rpc, connect from distributed.client import _wait from distributed.scheduler import Scheduler @@ -29,8 +29,7 @@ from distributed.worker import Worker, error_message, logger, TOTAL_MEMORY from distributed.utils import ignoring, tmpfile from distributed.utils_test import (loop, inc, mul, gen_cluster, div, dec, - slow, slowinc, throws, gen_test, readone) - + slow, slowinc, throws, gen_test, readone, cluster) def test_worker_ncores(): @@ -806,3 +805,82 @@ def __sizeof__(self): @gen_cluster() def test_pid(s, a, b): assert s.worker_info[a.address]['pid'] == os.getpid() + + +@gen_cluster(client=True) +def test_get_client(c, s, a, b): + def f(x): + cc = get_client() + future = cc.submit(inc, x) + return future.result() + + assert default_client() is c + + future = c.submit(f, 10, workers=a.address) + result = yield future + assert result == 11 + + assert a._client + assert not b._client + + assert a._client is c + assert default_client() is c + + a_client = a._client + + for i in range(10): + yield wait(c.submit(f, i)) + + assert a._client is a_client + + +def test_get_client_sync(loop): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + def f(x): + cc = get_client() + future = cc.submit(inc, x) + return future.result() + + future = c.submit(f, 10) + assert future.result() == 11 + + +@gen_cluster(client=True) +def test_get_client_coroutine(c, s, a, b): + @gen.coroutine + def f(): + client = yield get_client() + future = client.submit(inc, 10) + result = yield future + raise gen.Return(result) + + results = yield c.run_coroutine(f) + assert results == {a.address: 11, + b.address: 11} + + +def test_get_client_coroutine_sync(loop): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + @gen.coroutine + def f(): + client = yield get_client() + future = client.submit(inc, 10) + result = yield future + raise gen.Return(result) + + results = c.run_coroutine(f) + assert results == {a['address']: 11, + b['address']: 11} + + +@gen_cluster() +def test_global_workers(s, a, b): + from distributed.worker import _global_workers + n = len(_global_workers) + w = _global_workers[-1]() + assert w is a or w is b + yield a._close() + yield b._close() + assert len(_global_workers) == n - 2 diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index cdfbef8b094..e208adf2184 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -2,15 +2,17 @@ import random from time import sleep +import warnings import dask from dask import delayed import pytest from tornado import gen -from distributed import worker_client, Client, as_completed +from distributed import worker_client, Client, as_completed, get_worker from distributed.metrics import time from distributed.utils_test import gen_cluster, inc, double, cluster, loop +from distributed.worker import thread_state @gen_cluster(client=True) @@ -36,12 +38,13 @@ def func(x): @gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2) def test_scatter_from_worker(c, s, a, b): def func(): + from distributed.worker import thread_state with worker_client() as c: futures = c.scatter([1, 2, 3, 4, 5]) assert isinstance(futures, (list, tuple)) assert len(futures) == 5 - x = dict(c.worker.data) + x = dict(get_worker().data) y = {f.key: i for f, i in zip(futures, [1, 2, 3, 4, 5])} assert x == y @@ -61,7 +64,7 @@ def func(): o = object() futures = c.scatter({'x': o}) - correct &= c.worker.data['x'] is o + correct &= get_worker().data['x'] is o return correct future = c.submit(func) @@ -110,7 +113,7 @@ def func(): def test_same_loop(c, s, a, b): def f(): with worker_client() as lc: - return lc.loop is lc.worker.loop + return lc.loop is get_worker().loop future = c.submit(f) result = yield future @@ -159,11 +162,11 @@ def test_separate_thread_false(c, s, a): a.count = 0 def f(i): with worker_client(separate_thread=False) as client: - client.worker.count += 1 - assert client.worker.count <= 3 + get_worker().count += 1 + assert get_worker().count <= 3 sleep(random.random() / 40) - assert client.worker.count <= 3 - client.worker.count -= 1 + assert get_worker().count <= 3 + get_worker().count -= 1 return i futures = c.map(f, range(20)) @@ -198,3 +201,19 @@ def f(x): b2.compute() assert dask.context._globals['get'] == c.get + + +@gen_cluster(client=True) +def test_local_client_warning(c, s, a, b): + from distributed import local_client + def func(x): + with warnings.catch_warnings(record=True) as record: + with local_client() as c: + x = c.submit(inc, x) + result = x.result() + assert any("worker_client" in str(r.message) for r in record) + return result + + future = c.submit(func, 10) + result = yield future + assert result == 11 diff --git a/distributed/threadpoolexecutor.py b/distributed/threadpoolexecutor.py index efd7aad9a37..ba9d0ce460a 100644 --- a/distributed/threadpoolexecutor.py +++ b/distributed/threadpoolexecutor.py @@ -24,13 +24,13 @@ from concurrent.futures import thread import logging -from threading import local, Thread +import threading from .compatibility import get_thread_identity logger = logging.getLogger(__name__) -thread_state = local() +thread_state = threading.local() def _worker(executor, work_queue): @@ -57,23 +57,31 @@ def _worker(executor, work_queue): class ThreadPoolExecutor(thread.ThreadPoolExecutor): def _adjust_thread_count(self): if len(self._threads) < self._max_workers: - t = Thread(target=_worker, - name="ThreadPool worker %d" % len(self._threads,), - args=(self, self._work_queue)) + t = threading.Thread(target=_worker, + name="ThreadPool worker %d" % len(self._threads,), + args=(self, self._work_queue)) t.daemon = True self._threads.add(t) t.start() + def shutdown(self): + with threads_lock: + thread.ThreadPoolExecutor.shutdown(self) + def secede(): """ Have this thread secede from the ThreadPoolExecutor """ thread_state.proceed = False ident = get_thread_identity() - for t in list(thread_state.executor._threads): - if t.ident == ident: - thread_state.executor._threads.remove(t) - break + with threads_lock: + for t in list(thread_state.executor._threads): + if t.ident == ident: + thread_state.executor._threads.remove(t) + break + thread_state.executor._adjust_thread_count() + +threads_lock = threading.Lock() """ PSF LICENSE AGREEMENT FOR PYTHON 3.5.2 diff --git a/distributed/variable.py b/distributed/variable.py index e4310b4bd07..aaa2d53ecf7 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -15,6 +15,7 @@ from .client import Future, _get_global_client, Client from .metrics import time from .utils import tokey, log_errors +from .worker import get_client logger = logging.getLogger(__name__) @@ -198,7 +199,9 @@ def __getstate__(self): def __setstate__(self, state): name, address = state - client = _get_global_client() - if client is None or client.scheduler.address != address: + try: + client = get_client(address) + assert client.address == address + except (AttributeError, AssertionError): client = Client(address, set_as_default=False) self.__init__(name=name, client=client) diff --git a/distributed/worker.py b/distributed/worker.py index 0a7dafd3948..cd894cf7f13 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -8,9 +8,10 @@ from pickle import PicklingError import random import tempfile -from threading import current_thread, local +import threading import shutil import sys +import weakref from dask.core import istask from dask.compatibility import apply @@ -26,7 +27,7 @@ from .batched import BatchedSend from .comm import get_address_host, get_local_address_for from .config import config -from .compatibility import unicode +from .compatibility import unicode, get_thread_identity from .core import (error_message, CommClosedError, rpc, pingpong, coerce_to_address) from .metrics import time @@ -36,7 +37,7 @@ serialize_bytelist) from .security import Security from .sizeof import safe_sizeof as sizeof -from .threadpoolexecutor import ThreadPoolExecutor +from .threadpoolexecutor import ThreadPoolExecutor, secede as tpe_secede from .utils import (funcname, get_ip, has_arg, _maybe_complex, log_errors, ignoring, validate_key, mp_context, import_file, silence_logging) @@ -44,7 +45,7 @@ _ncores = mp_context.cpu_count() -thread_state = local() +thread_state = threading.local() logger = logging.getLogger(__name__) @@ -66,6 +67,8 @@ PROCESSING = ('waiting', 'ready', 'constrained', 'executing', 'long-running') READY = ('ready', 'constrained') +_global_workers = [] + class WorkerBase(ServerNode): def __init__(self, scheduler_ip, scheduler_port=None, ncores=None, @@ -317,6 +320,8 @@ def _close(self, report=True, timeout=10, nanny=True): if self.status in ('closed', 'closing'): return logger.info("Stopping worker at %s", self.address) + if self._client: + yield self._client._close() self.status = 'closing' self.stop() self.heartbeat_callback.stop() @@ -341,6 +346,18 @@ def _close(self, report=True, timeout=10, nanny=True): self.rpc.close() self._closed.set() + self._remove_from_global_workers() + + def __del__(self): + self._remove_from_global_workers() + + def _remove_from_global_workers(self): + with log_errors(): + for ref in list(_global_workers): + if ref() is self: + _global_workers.remove(ref) + if ref() is None: + _global_workers.remove(ref) @gen.coroutine def terminate(self, comm, report=True): @@ -652,6 +669,7 @@ def apply_function(function, args, kwargs, execution_state, key): ------- msg: dictionary with status, result/error, timings, etc.. """ + thread_state.start_time = time() thread_state.execution_state = execution_state thread_state.key = key start = time() @@ -670,7 +688,7 @@ def apply_function(function, args, kwargs, execution_state, key): end = time() msg['start'] = start msg['stop'] = end - msg['thread'] = current_thread().ident + msg['thread'] = get_thread_identity() return msg @@ -944,6 +962,7 @@ def __init__(self, *args, **kwargs): self.has_what = defaultdict(set) self.pending_data_per_worker = defaultdict(deque) self.extensions = {} + self._lock = threading.Lock() self.data_needed = deque() # TODO: replace with heap? @@ -1005,9 +1024,12 @@ def __init__(self, *args, **kwargs): self.incoming_count = 0 self.outgoing_transfer_log = deque(maxlen=(100000)) self.outgoing_count = 0 + self._client = None WorkerBase.__init__(self, *args, **kwargs) + _global_workers.append(weakref.ref(self)) + def __str__(self): return "<%s: %s, %s, stored: %d, running: %d/%d, ready: %d, comm: %d, waiting: %d>" % ( self.__class__.__name__, self.address, self.status, @@ -1399,14 +1421,16 @@ def transition_executing_done(self, key, value=no_value): import pdb; pdb.set_trace() raise - def transition_executing_long_running(self, key): + def transition_executing_long_running(self, key, compute_duration=None): try: if self.validate: assert key in self.executing self.executing.remove(key) self.long_running.add(key) - self.batched_stream.send({'op': 'long-running', 'key': key}) + self.batched_stream.send({'op': 'long-running', + 'key': key, + 'compute_duration': compute_duration}) self.ensure_computing() except Exception as e: @@ -1415,6 +1439,10 @@ def transition_executing_long_running(self, key): import pdb; pdb.set_trace() raise + def maybe_transition_long_running(self, key, compute_duration=None): + if self.task_state.get(key) == 'executing': + self.transition(key, 'long-running', compute_duration=compute_duration) + ########################## # Gather Data from Peers # ########################## @@ -2103,6 +2131,40 @@ def story(self, *keys): for c in msg if isinstance(c, (tuple, list, set)))] + @property + def client(self): + """ Get local client attached to this worker + + If no such client exists, create one + + See Also + -------- + get_client + """ + with self._lock: + if self._client: + return self._client + + try: + from .client import default_client + client = default_client() + except ValueError: # no clients found, need to make a new one + pass + else: + if client.scheduler.address == self.scheduler.address: + self._client = client + + if not self._client: + from .client import Client + asynchronous = get_thread_identity() == self.loop._thread_ident + self._client = Client(self.scheduler.address, loop=self.loop, + security=self.security, + set_as_default=True, + asynchronous=asynchronous) + if not asynchronous: + assert self._client.status == 'running' + return self._client + def get_worker(): """ Get the worker currently running this task @@ -2119,6 +2181,83 @@ def get_worker(): See Also -------- + get_client + worker_client + """ + try: + return thread_state.execution_state['worker'] + except AttributeError: + for ref in _global_workers[::-1]: + worker = ref() + if worker: + return worker + raise ValueError("No workers found") + + +def get_client(address=None): + """ Get a client while within a task + + This client connects to the same scheduler to which the worker is connected + + Examples + -------- + >>> def f(): + ... client = get_client() + ... futures = client.map(lambda x: x + 1, range(10)) # spawn many tasks + ... results = client.gather(futures) + ... return sum(results) + + >>> future = client.submit(f) # doctest: +SKIP + >>> future.result() # doctest: +SKIP + 55 + + See Also + -------- + get_worker worker_client + secede + """ + try: + worker = get_worker() + except ValueError: # could not find worker + pass + else: + if not address or worker.scheduler.address == address: + return worker.client + + from .client import _get_global_client + client = _get_global_client() # TODO: assumes the same scheduler + if client and not address or client.scheduler.address == address: + return client + elif address: + from .client import Client + return Client(address) + else: + raise ValueError("No global client found and no address provided") + + +def secede(): + """ + Have this task secede from the worker's thread pool + + This opens up a new scheduling slot and a new thread for a new task. + + Examples + -------- + >>> def mytask(x): + ... # do some work + ... client = get_client() + ... futures = client.map(...) # do some remote work + ... secede() # while that work happens, remove ourself from the pool + ... return client.gather(futures) # return gathered results + + See Also + -------- + get_client + get_worker """ - return thread_state.execution_state['worker'] + worker = get_worker() + tpe_secede() # have this thread secede from the thread pool + duration = time() - thread_state.start_time + worker.loop.add_callback(worker.maybe_transition_long_running, + thread_state.key, compute_duration=duration) diff --git a/distributed/worker_client.py b/distributed/worker_client.py index a6b9b215462..63ddfba3973 100644 --- a/distributed/worker_client.py +++ b/distributed/worker_client.py @@ -1,17 +1,9 @@ from __future__ import print_function, division, absolute_import from contextlib import contextmanager -from datetime import timedelta -from toolz import keymap, valmap, merge +import warnings -from dask.base import tokenize -from tornado import gen - -from .client import AllExit, Client, Future, pack_data, unpack_remotedata -from dask.compatibility import apply -from .sizeof import sizeof from .threadpoolexecutor import secede -from .utils import All, log_errors, sync, tokey, ignoring from .worker import thread_state, get_worker @@ -45,128 +37,17 @@ def worker_client(timeout=3, separate_thread=True): See Also -------- get_worker + get_client + secede """ - address = thread_state.execution_state['scheduler'] - worker = thread_state.execution_state['worker'] + worker = get_worker() if separate_thread: secede() # have this thread secede from the thread pool worker.loop.add_callback(worker.transition, thread_state.key, 'long-running') - with WorkerClient(address, loop=worker.loop, security=worker.security, - asynchronous=True) as wc: - # Make sure connection errors are bubbled to the caller - sync(wc.loop, gen.with_timeout, timedelta(seconds=timeout), wc._started) - wc.asynchronous = False - assert wc.status == 'running' - yield wc - - -local_client = worker_client - - -class WorkerClient(Client): - """ An Client designed to operate from a Worker process - - This client has had a few methods altered to make it more efficient for - working directly from the worker nodes. In particular scatter/gather first - look to the local data dictionary rather than sending data over the network - """ - def __init__(self, *args, **kwargs): - loop = kwargs.get('loop') - self.worker = get_worker() - kwargs['set_as_default'] = False - sync(loop, apply, Client.__init__, (self,) + args, kwargs) - - @gen.coroutine - def _scatter(self, data, workers=None, broadcast=False, direct=None): - """ Scatter data to local data dictionary - - Rather than send data out to the cluster we keep data local. However - we do report to the scheduler that the local worker has the scattered - data. This allows other workers to come by and steal this data if - desired. - - Keywords like ``broadcast=`` do not work, however operations like - ``.replicate`` work fine after calling scatter, which can fill in for - this functionality. - """ - with log_errors(): - if not (workers is None and broadcast is False): - raise NotImplementedError("Scatter from worker doesn't support workers or broadcast keywords") - - if isinstance(data, dict) and not all(isinstance(k, (bytes, str)) - for k in data): - d = yield self._scatter(keymap(tokey, data), workers, broadcast) - raise gen.Return({k: d[tokey(k)] for k in data}) - - if isinstance(data, type(range(0))): - data = list(data) - input_type = type(data) - names = False - unpack = False - if isinstance(data, (set, frozenset)): - data = list(data) - if not isinstance(data, (dict, list, tuple, set, frozenset)): - unpack = True - data = [data] - if isinstance(data, (list, tuple)): - names = list(map(tokenize, data)) - data = dict(zip(names, data)) - - types = valmap(type, data) - assert isinstance(data, dict) - - self.worker.update_data(data=data, report=False) - - yield self.scheduler.update_data( - who_has={key: [self.worker.address] for key in data}, - nbytes=valmap(sizeof, data), - client=self.id) - - out = {k: self._Future(k, self) for k in data} - for key, typ in types.items(): - self.futures[key].finish(type=typ) - - if issubclass(input_type, (list, tuple, set, frozenset)): - out = input_type(out[k] for k in names) - - if unpack: - assert len(out) == 1 - out = list(out.values())[0] - raise gen.Return(out) - - @gen.coroutine - def _gather(self, futures, errors='raise', direct=False): - """ - - Exactly like Client._gather, but get data directly from the local - worker data dictionary directly rather than through the scheduler. - - TODO: avoid scheduler for other communications, and assume that we can - communicate directly with the other workers. - """ - futures2, keys = unpack_remotedata(futures, byte_keys=True) - keys = [tokey(k) for k in keys] - - @gen.coroutine - def wait(k): - """ Want to stop the All(...) early if we find an error """ - yield self.futures[k].event.wait() - if self.futures[k].status != 'finished': - raise AllExit() - - with ignoring(AllExit): - yield All([wait(key) for key in keys if key in self.futures]) - - local = {k: self.worker.data[k] for k in keys - if k in self.worker.data} - - futures3 = {k: Future(k, self) for k in keys if k not in local} + yield worker.client - futures4 = pack_data(futures2, merge(local, futures3)) - if not futures3: - raise gen.Return(futures4) - result = yield Client._gather(self, futures4, errors=errors, - direct=True) - raise gen.Return(result) +def local_client(*args, **kwargs): + warnings.warn("local_client has moved to worker_client") + return worker_client(*args, **kwargs) diff --git a/docs/source/api.rst b/docs/source/api.rst index bf74fef7ea5..33623277b38 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -42,6 +42,8 @@ API .. autosummary:: worker_client get_worker + get_client + secede .. currentmodule:: distributed.recreate_exceptions @@ -141,6 +143,8 @@ Other .. autofunction:: distributed.worker_client .. autofunction:: distributed.get_worker +.. autofunction:: distributed.get_client +.. autofunction:: distributed.secede .. autoclass:: Queue :members: diff --git a/docs/source/channels.rst b/docs/source/channels.rst deleted file mode 100644 index ad1ff1bb42e..00000000000 --- a/docs/source/channels.rst +++ /dev/null @@ -1,149 +0,0 @@ -Shared Futures With Channels -============================ - -A channel is a changing stream of data or futures shared between any number of -clients connected to the same scheduler. - -Any client can push msgpack-encodable data (numbers, strings, lists, dicts) or -Future objects into a channel. All other clients listening to the channel will -receive that data or future as a linear sequence. Communication happens -through the scheduler, which linearizes everything. Channels should only be -used to move small amounts of administrative data. For larger data, create -futures and send the futures instead. - - -Examples --------- - -Basic Usage -~~~~~~~~~~~ - -Create channels from your Client: - -.. code-block:: python - - >>> client = Client('scheduler-address:8786') - >>> chan = client.channel('my-channel') - -Append futures or data onto a channel - -.. code-block:: python - - >>> future = client.submit(add, 1, 2) - >>> chan.append(future) - >>> chan.append({'x': 123}) - -A channel maintains a collection of current futures added by both your -client, and others. - -.. code-block:: python - - >>> chan.data - deque([, - {'x': 123}]) - -If you wish to persist the current status of tasks outside of the distributed -cluster (e.g. to take a snapshot in case of shutdown) you can copy a channel full -of data as it is a python deque_. - -.. _deque: https://docs.python.org/3.5/library/collections.html#collections.deque` - -.. code-block:: python - - channelcopy = list(chan.data) - -You can iterate over a channel to get back data. - -.. code-block:: python - - >>> anotherclient = Client('scheduler-address:8786') - >>> chan = anotherclient.channel('my-channel') - >>> for element in chan: - ... pass - -When done writing, call flush to wait until your appends have been -fully registered with the scheduler. - -.. code-block:: python - - >>> client = Client('scheduler-address:8786') - >>> chan = client.channel('my-channel') - >>> future2 = client.submit(time.sleep, 2) - >>> chan.append(future2) - >>> chan.flush() - - -Example with worker_client -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Using channels with `worker_client`_ allows for a more decoupled version -of what is possible with :doc:`Data Streams with Queues` -in that independent worker clients can build up a set of results -which can be read later by a different client. -This opens up Dask/Distributed to being integrated in a wider application -environment similar to other python task queues such as Celery_. - -.. _worker_client: http://distributed.readthedocs.io/en/latest/task-launch.html#submit-tasks-from-worker -.. _Celery: http://www.celeryproject.org/ - -.. code-block:: python - - import random, time, operator - from distributed import Client, worker_client - from time import sleep - - def emit(name): - with worker_client() as c: - chan = c.channel(name) - while True: - future = c.submit(random.random, pure=False) - chan.append(future) - sleep(1) - - def combine(): - with worker_client() as c: - a_chan = c.channel('a') - b_chan = c.channel('b') - out_chan = c.channel('adds') - for a, b in zip(a_chan, b_chan): - future = c.submit(operator.add, a, b) - out_chan.append(future) - - client = Client() - - emitters = (client.submit(emit, 'a'), client.submit(emit, 'b')) - combiner = client.submit(combine) - chan = client.channel('adds') - - - for future in chan: - print(future.result()) - ...: - 1.782009416831722 - ... - -All iterations on a channel by different clients can be stopped using the ``stop`` method - -.. code-block:: python - - chan.stop() - - -Additional Applications ------------------------ - -Channels can serve as a coordination point or semaphore. They can signal -stopping criteria for iterative processes. - -Short lived clients, such as occur when firing off controlling tasks from a web -application, AWS Lambda, or other fire and forget script, often need a place to -store their futures so that in-flight work doesn't get garbage collected. -Because channels act as clients for the purpose of garbage collection (all -futures within a Channel are considered desired) they can serve as this -repository after short-lived clients die off. - -Worker clients can communicate large amounts of data to each other using -channels by first scattering local data to themselves, creating futures, and -then pushing those futures down a shared channel. When subscribers to the -channel gather these futures they will engage the normal high-bandwidth -inter-worker communication mechanism. diff --git a/docs/source/index.rst b/docs/source/index.rst index ffb6b520c1d..55845e2aacf 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -104,7 +104,6 @@ Contents adaptive asynchronous - channels configuration ec2 local-cluster