diff --git a/riak/client/__init__.py b/riak/client/__init__.py index ea9abaca..fae5e133 100644 --- a/riak/client/__init__.py +++ b/riak/client/__init__.py @@ -16,7 +16,7 @@ from riak.security import SecurityCreds from riak.util import lazy_property, bytes_to_str, str_to_bytes from six import string_types, PY2 -from riak.client.multiget import MultiGetPool +from riak.client.multi import MultiGetPool, MultiPutPool def default_encoder(obj): @@ -67,8 +67,10 @@ class RiakClient(RiakMapReduceChain, RiakClientOperations): #: The supported protocols PROTOCOLS = ['http', 'pbc'] - def __init__(self, protocol='pbc', transport_options={}, nodes=None, - credentials=None, multiget_pool_size=None, **kwargs): + def __init__(self, protocol='pbc', transport_options={}, + nodes=None, credentials=None, + multiget_pool_size=None, multiput_pool_size=None, + **kwargs): """ Construct a new ``RiakClient`` object. @@ -87,6 +89,10 @@ def __init__(self, protocol='pbc', transport_options={}, nodes=None, :meth:`multiget` operations. Defaults to a factor of the number of CPUs in the system :type multiget_pool_size: int + :param multiput_pool_size: the number of threads to use in + :meth:`multiput` operations. 
Defaults to a factor of the number of + CPUs in the system + :type multiput_pool_size: int """ kwargs = kwargs.copy() @@ -96,6 +102,7 @@ def __init__(self, protocol='pbc', transport_options={}, nodes=None, self.nodes = [self._create_node(n) for n in nodes] self._multiget_pool_size = multiget_pool_size + self._multiput_pool_size = multiput_pool_size self.protocol = protocol or 'pbc' self._resolver = None self._credentials = self._create_credentials(credentials) @@ -358,6 +365,13 @@ def _multiget_pool(self): else: return None + @lazy_property + def _multiput_pool(self): + if self._multiput_pool_size: + return MultiPutPool(self._multiput_pool_size) + else: + return None + def __hash__(self): return hash(frozenset([(n.host, n.http_port, n.pb_port) for n in self.nodes])) diff --git a/riak/client/multiget.py b/riak/client/multi.py similarity index 59% rename from riak/client/multiget.py rename to riak/client/multi.py index 9b5d7522..672c32b9 100644 --- a/riak/client/multiget.py +++ b/riak/client/multi.py @@ -3,13 +3,12 @@ from threading import Thread, Lock, Event from multiprocessing import cpu_count from six import PY2 - if PY2: from Queue import Queue else: from queue import Queue -__all__ = ['multiget', 'MultiGetPool'] +__all__ = ['multiget', 'multiput', 'MultiGetPool', 'MultiPutPool'] try: @@ -21,18 +20,21 @@ POOL_SIZE = 6 #: A :class:`namedtuple` for tasks that are fed to workers in the -#: multiget pool. -Task = namedtuple('Task', ['client', 'outq', 'bucket_type', 'bucket', 'key', - 'options']) +#: multi pool. +Task = namedtuple( + 'Task', + ['client', 'outq', + 'bucket_type', 'bucket', 'key', + 'object', 'options']) -class MultiGetPool(object): +class MultiPool(object): """ - Encapsulates a pool of fetcher threads. These threads can be used - across many multi-get requests. + Encapsulates a pool of threads. These threads can be used + across many multi requests. 
""" - def __init__(self, size=POOL_SIZE): + def __init__(self, size=POOL_SIZE, name='unknown'): """ :param size: the desired size of the worker pool :type size: int @@ -40,6 +42,7 @@ def __init__(self, size=POOL_SIZE): self._inq = Queue() self._size = size + self._name = name self._started = Event() self._stop = Event() self._lock = Lock() @@ -57,14 +60,14 @@ def enq(self, task): if not self._stop.is_set(): self._inq.put(task) else: - raise RuntimeError("Attempted to enqueue a fetch operation while " - "multi-get pool was shutdown!") + raise RuntimeError("Attempted to enqueue an operation while " + "multi pool was shutdown!") def start(self): """ Starts the worker threads if they are not already started. This method is thread-safe and will be called automatically - when executing a MultiGet operation. + when executing an operation. """ # Check whether we are already started, skip if we are. if not self._started.is_set(): @@ -73,8 +76,9 @@ def start(self): # If we got the lock, go ahead and start the worker # threads, set the started flag, and release the lock. for i in range(self._size): - name = "riak.client.multiget-worker-{0}".format(i) - worker = Thread(target=self._fetcher, name=name) + name = "riak.client.multi-worker-{0}-{1}".format( + self._name, i) + worker = Thread(target=self._worker_method, name=name) worker.daemon = True worker.start() self._workers.append(worker) @@ -105,7 +109,26 @@ def __del__(self): # shutting down. self.stop() - def _fetcher(self): + def _worker_method(self): + raise NotImplementedError + + def _should_quit(self): + """ + Worker threads should exit when the stop flag is set and the + input queue is empty. Once the stop flag is set, new enqueues + are disallowed, meaning that the workers can safely drain the + queue before exiting. 
+ + :rtype: bool + """ + return self.stopped() and self._inq.empty() + + +class MultiGetPool(MultiPool): + def __init__(self, size=POOL_SIZE): + super(MultiGetPool, self).__init__(size=size, name='get') + + def _worker_method(self): """ The body of the multi-get worker. Loops until :meth:`_should_quit` returns ``True``, taking tasks off the @@ -121,24 +144,40 @@ def _fetcher(self): except KeyboardInterrupt: raise except Exception as err: - task.outq.put((task.bucket_type, task.bucket, task.key, err), ) + errdata = (task.bucket_type, task.bucket, task.key, err) + task.outq.put(errdata) finally: self._inq.task_done() - def _should_quit(self): - """ - Worker threads should exit when the stop flag is set and the - input queue is empty. Once the stop flag is set, new enqueues - are disallowed, meaning that the workers can safely drain the - queue before exiting. - :rtype: bool +class MultiPutPool(MultiPool): + def __init__(self, size=POOL_SIZE): + super(MultiPutPool, self).__init__(size=size, name='put') + + def _worker_method(self): """ - return self.stopped() and self._inq.empty() + The body of the multi-put worker. Loops until + :meth:`_should_quit` returns ``True``, taking tasks off the + input queue, storing the object, and putting the result on + the output queue. + """ + while not self._should_quit(): + task = self._inq.get() + try: + robj = task.object + rv = task.client.put(robj, **task.options) + task.outq.put(rv) + except KeyboardInterrupt: + raise + except Exception as err: + errdata = (task.object, err) + task.outq.put(errdata) + finally: + self._inq.task_done() -#: The default pool is automatically created and stored in this constant. 
RIAK_MULTIGET_POOL = MultiGetPool() +RIAK_MULTIPUT_POOL = MultiPutPool() def multiget(client, keys, **options): @@ -160,8 +199,8 @@ def multiget(client, keys, **options): :meth:`RiakBucket.get ` :type options: dict :rtype: list - """ + """ outq = Queue() if 'pool' in options: @@ -172,7 +211,7 @@ def multiget(client, keys, **options): pool.start() for bucket_type, bucket, key in keys: - task = Task(client, outq, bucket_type, bucket, key, options) + task = Task(client, outq, bucket_type, bucket, key, None, options) pool.enq(task) results = [] @@ -184,3 +223,48 @@ def multiget(client, keys, **options): outq.task_done() return results + + +def multiput(client, objs, **options): + """Executes a parallel-store across multiple threads. Returns a list + containing booleans or :class:`~riak.riak_object.RiakObject` + + If a ``pool`` option is included, the request will use the given worker + pool and not the default :data:`RIAK_MULTIPUT_POOL`. This option will + be passed by the client if the ``multiput_pool_size`` option was set on + client initialization. 
+ + :param client: the client to use + :type client: :class:`RiakClient <riak.client.RiakClient>` + :param objs: the Riak Objects to store in parallel + :type objs: list of `RiakObject <riak.riak_object.RiakObject>` + :param options: request options to + :meth:`RiakClient.put <riak.client.RiakClient.put>` + :type options: dict + :rtype: list + """ + outq = Queue() + + if 'pool' in options: + pool = options['pool'] + del options['pool'] + else: + pool = RIAK_MULTIPUT_POOL + + pool.start() + for robj in objs: + bucket_type = robj.bucket.bucket_type + bucket = robj.bucket.name + key = robj.key + task = Task(client, outq, bucket_type, bucket, key, robj, options) + pool.enq(task) + + results = [] + for _ in range(len(objs)): + if pool.stopped(): + raise RuntimeError("Multi-put operation interrupted by pool " + "stopping!") + results.append(outq.get()) + outq.task_done() + + return results diff --git a/riak/client/operations.py b/riak/client/operations.py index d3541b3c..8bf2b9c2 100644 --- a/riak/client/operations.py +++ b/riak/client/operations.py @@ -1,6 +1,7 @@ +import riak.client.multi + from riak.client.transport import RiakClientTransport, \ retryable, retryableHttpOnly -from riak.client.multiget import multiget from riak.client.index_page import IndexPage from riak.datatypes import TYPES from riak.table import Table @@ -976,7 +977,22 @@ def multiget(self, pairs, **params): """ if self._multiget_pool: params['pool'] = self._multiget_pool - return multiget(self, pairs, **params) + return riak.client.multi.multiget(self, pairs, **params) + + def multiput(self, objs, **params): + """ + Stores objects in parallel via threads. + + :param objs: the objects to store + :type objs: list of `RiakObject <riak.riak_object.RiakObject>` + :param params: additional request flags, e.g. 
w, dw, pw + :type params: dict + :rtype: list of boolean or + :class:`RiakObjects <riak.riak_object.RiakObject>` + """ + if self._multiput_pool: + params['pool'] = self._multiput_pool + return riak.client.multi.multiput(self, objs, **params) @retryable def get_counter(self, transport, bucket, key, r=None, pr=None, diff --git a/riak/tests/test_client.py b/riak/tests/test_client.py index 19379d06..d39b3290 100644 --- a/riak/tests/test_client.py +++ b/riak/tests/test_client.py @@ -149,6 +149,40 @@ def test_multiget_errors(self): self.assertIsInstance(failure[3], Exception) client.close() + def test_multiput_errors(self): + """ + Unrecoverable errors are captured along with the object + and not propagated. + """ + client = self.create_client(http_port=DUMMY_HTTP_PORT, + pb_port=DUMMY_PB_PORT) + bucket = client.bucket(self.bucket_name) + k1 = self.randname() + k2 = self.randname() + o1 = RiakObject(client, bucket, k1) + o2 = RiakObject(client, bucket, k2) + + if PY2: + o1.encoded_data = k1 + o2.encoded_data = k2 + else: + o1.data = k1 + o2.data = k2 + + objs = [o1, o2] + for robj in objs: + robj.content_type = 'text/plain' + + results = client.multiput(objs, return_body=True) + for failure in results: + self.assertIsInstance(failure, tuple) + self.assertIsInstance(failure[0], RiakObject) + if PY2: + self.assertIsInstance(failure[1], StandardError) # noqa + else: + self.assertIsInstance(failure[1], Exception) + client.close() + def test_multiget_notfounds(self): """ Not founds work in multiget just the same as get. @@ -189,6 +223,74 @@ def test_multiget_pool_size(self): self.assertEqual(obj.key, obj.data) client.close() + def test_multiput_pool_size(self): + """ + The pool size for multiputs can be configured at client initiation + time. Multiput still works as expected. 
+ """ + client = self.create_client(multiput_pool_size=2) + self.assertEqual(2, client._multiput_pool._size) + + bucket = client.bucket(self.bucket_name) + k1 = self.randname() + k2 = self.randname() + o1 = RiakObject(client, bucket, k1) + o2 = RiakObject(client, bucket, k2) + + if PY2: + o1.encoded_data = k1 + o2.encoded_data = k2 + else: + o1.data = k1 + o2.data = k2 + + objs = [o1, o2] + for robj in objs: + robj.content_type = 'text/plain' + + results = client.multiput(objs, return_body=True) + for obj in results: + self.assertIsInstance(obj, RiakObject) + self.assertTrue(obj.exists) + self.assertEqual(obj.content_type, 'text/plain') + if PY2: + self.assertEqual(obj.key, obj.encoded_data) + else: + self.assertEqual(obj.key, obj.data) + client.close() + + def test_multiput_pool_options(self): + sz = 4 + client = self.create_client(multiput_pool_size=sz) + self.assertEqual(sz, client._multiput_pool._size) + + bucket = client.bucket(self.bucket_name) + k1 = self.randname() + k2 = self.randname() + o1 = RiakObject(client, bucket, k1) + o2 = RiakObject(client, bucket, k2) + + if PY2: + o1.encoded_data = k1 + o2.encoded_data = k2 + else: + o1.data = k1 + o2.data = k2 + + objs = [o1, o2] + for robj in objs: + robj.content_type = 'text/plain' + + results = client.multiput(objs, return_body=False) + for obj in results: + if client.protocol == 'pbc': + self.assertIsInstance(obj, RiakObject) + self.assertFalse(obj.exists) + self.assertEqual(obj.content_type, 'text/plain') + else: + self.assertIsNone(obj) + client.close() + @unittest.skipUnless(RUN_POOL, 'RUN_POOL is 0') def test_pool_close(self): """