diff --git a/riak/client/transport.py b/riak/client/transport.py index 6aca7f24..bb2aaef9 100644 --- a/riak/client/transport.py +++ b/riak/client/transport.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from riak.transports.pool import BadResource -from riak.transports.tcp import is_retryable as is_pbc_retryable +from riak.transports.tcp import is_retryable as is_tcp_retryable from riak.transports.http import is_retryable as is_http_retryable import threading from six import PY2 @@ -162,7 +162,7 @@ def _is_retryable(error): :type error: Exception :rtype: boolean """ - return is_pbc_retryable(error) or is_http_retryable(error) + return is_tcp_retryable(error) or is_http_retryable(error) def retryable(fn, protocol=None): diff --git a/riak/riak_error.py b/riak/riak_error.py index ce582bbb..b99eb7fe 100644 --- a/riak/riak_error.py +++ b/riak/riak_error.py @@ -21,8 +21,12 @@ class RiakError(Exception): """ Base class for exceptions generated in the Riak API. """ - def __init__(self, value): - self.value = value + def __init__(self, *args, **kwargs): + super(RiakError, self).__init__(*args, **kwargs) + if len(args) > 0: + self.value = args[0] + else: + self.value = 'unknown' def __str__(self): return repr(self.value) @@ -34,5 +38,5 @@ class ConflictError(RiakError): :class:`~riak.riak_object.RiakObject` that has more than one sibling. """ - def __init__(self, message="Object in conflict"): + def __init__(self, message='Object in conflict'): super(ConflictError, self).__init__(message) diff --git a/riak/tests/test_btypes.py b/riak/tests/test_btypes.py index d0fe728b..97d1b1a6 100644 --- a/riak/tests/test_btypes.py +++ b/riak/tests/test_btypes.py @@ -151,13 +151,18 @@ def test_multiget_bucket_types(self): self.assertEqual(btype, mobj.bucket.bucket_type) def test_write_once_bucket_type(self): - btype = self.client.bucket_type('write_once') - bucket = btype.bucket(self.bucket_name) - - for i in range(100): - obj = bucket.new(self.key_name + str(i)) - obj.data = {'id': i} - obj.store() + bt = 'write_once' + skey = 'write_once-init' + btype = self.client.bucket_type(bt) + bucket = btype.bucket(bt) + sobj = bucket.get(skey) + if not sobj.exists: + for i in range(100): + o = bucket.new(self.key_name + str(i)) + o.data = {'id': i} + o.store() + o = bucket.new(skey, data={'id': skey}) + o.store() mget = bucket.multiget([self.key_name + str(i) for i in range(100)]) for mobj in mget: diff --git a/riak/tests/test_client.py b/riak/tests/test_client.py index d39b3290..a9c3a380 100644 --- a/riak/tests/test_client.py +++ b/riak/tests/test_client.py @@ -3,6 +3,7 @@ from six import PY2 from threading import Thread from riak.riak_object import RiakObject +from riak.transports.tcp import TcpTransport from riak.tests import DUMMY_HTTP_PORT, DUMMY_PB_PORT, RUN_POOL from riak.tests.base import IntegrationTestBase @@ -13,6 +14,17 @@ class ClientTests(IntegrationTestBase, unittest.TestCase): + def test_can_set_tcp_keepalive(self): + if self.protocol == 'pbc': + topts = {'socket_keepalive': True} + c = self.create_client(transport_options=topts) + for i, r in enumerate(c._tcp_pool.resources): + self.assertIsInstance(r, TcpTransport) + self.assertTrue(r._socket_keepalive) + c.close() + else: + pass + def test_uses_client_id_if_given(self): if self.protocol == 'pbc': zero_client_id = "\0\0\0\0" diff --git a/riak/tests/test_kv.py b/riak/tests/test_kv.py index 5513c603..aeebed68 100644 --- a/riak/tests/test_kv.py +++ b/riak/tests/test_kv.py @@ -180,17 +180,29 @@ def test_string_bucket_name(self): def test_generate_key(self): # Ensure that Riak generates a random key when # the key passed to bucket.new() is None. - bucket = self.client.bucket('random_key_bucket') - existing_keys = bucket.get_keys() + bucket = self.client.bucket(self.bucket_name) o = bucket.new(None, data={}) self.assertIsNone(o.key) o.store() self.assertIsNotNone(o.key) self.assertNotIn('/', o.key) - self.assertNotIn(o.key, existing_keys) - self.assertEqual(len(bucket.get_keys()), len(existing_keys) + 1) + existing_keys = bucket.get_keys() + self.assertEqual(len(existing_keys), 1) + + def maybe_store_keys(self): + skey = 'rkb-init' + bucket = self.client.bucket('random_key_bucket') + sobj = bucket.get(skey) + if sobj.exists: + return + for key in range(1, 1000): + o = bucket.new(None, data={}) + o.store() + o = bucket.new(skey, data={}) + o.store() def test_stream_keys(self): + self.maybe_store_keys() bucket = self.client.bucket('random_key_bucket') regular_keys = bucket.get_keys() self.assertNotEqual(len(regular_keys), 0) @@ -203,10 +215,8 @@ def test_stream_keys(self): self.assertEqual(sorted(regular_keys), sorted(streamed_keys)) def test_stream_keys_timeout(self): + self.maybe_store_keys() bucket = self.client.bucket('random_key_bucket') - for key in range(1, 1000): - o = bucket.new(None, data={}) - o.store() streamed_keys = [] with self.assertRaises(RiakError): for keylist in self.client.stream_keys(bucket, timeout=1): @@ -216,6 +226,7 @@ def test_stream_keys_timeout(self): streamed_keys += keylist def test_stream_keys_abort(self): + self.maybe_store_keys() bucket = self.client.bucket('random_key_bucket') regular_keys = bucket.get_keys() self.assertNotEqual(len(regular_keys), 0) diff --git a/riak/transports/tcp/__init__.py b/riak/transports/tcp/__init__.py index 312f9194..2634af0a 100644 --- a/riak/transports/tcp/__init__.py +++ b/riak/transports/tcp/__init__.py @@ -42,7 +42,7 @@ def destroy_resource(self, tcp): def is_retryable(err): """ Determines if the given exception is something that is - network/socket-related and should thus cause the PBC connection to + network/socket-related and should thus cause the TCP connection to close and the operation retried on another node. :rtype: boolean diff --git a/riak/transports/tcp/connection.py b/riak/transports/tcp/connection.py index a9d75603..aabcb52e 100644 --- a/riak/transports/tcp/connection.py +++ b/riak/transports/tcp/connection.py @@ -1,3 +1,4 @@ +import errno import socket import struct @@ -7,16 +8,23 @@ from riak import RiakError from riak.codecs.pbuf import PbufCodec from riak.security import SecurityError, USE_STDLIB_SSL +from riak.transports.pool import BadResource -if not USE_STDLIB_SSL: - from OpenSSL.SSL import Connection - from riak.transports.security import configure_pyopenssl_context -else: +if USE_STDLIB_SSL: import ssl from riak.transports.security import configure_ssl_context +else: + from OpenSSL.SSL import Connection + from riak.transports.security import configure_pyopenssl_context class TcpConnection(object): + # These are set in the TcpTransport initializer + _address = None + _timeout = None + _socket_keepalive = None + _socket_tcp_options = None + """ Connection-related methods for TcpTransport. """ @@ -174,6 +182,10 @@ def _recv(self, msglen): toread = msglen while toread: nbytes = self._socket.recv_into(view, toread) + # https://docs.python.org/2/howto/sockets.html#using-a-socket + # https://github.com/basho/riak-python-client/issues/399 + if nbytes == 0: + raise BadResource('recv_into returned zero bytes unexpectedly') view = view[nbytes:] # slicing views is cheap toread -= nbytes nread += nbytes @@ -189,6 +201,13 @@ def _connect(self): self._timeout) else: self._socket = socket.create_connection(self._address) + if self._socket_tcp_options: + ka_opts = self._socket_tcp_options + for k, v in ka_opts.iteritems(): + self._socket.setsockopt(socket.SOL_TCP, k, v) + if self._socket_keepalive: + self._socket.setsockopt( + socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) if self._client._credentials: self._init_security() @@ -197,9 +216,15 @@ def close(self): Closes the underlying socket of the PB connection. """ if self._socket: + if USE_STDLIB_SSL: + # NB: Python 2.7.8 and earlier does not have a compatible + # shutdown() method due to the SSL lib + try: + self._socket.shutdown(socket.SHUT_RDWR) + except IOError as e: + # NB: sometimes this is the exception if the initial + # connection didn't succeed correctly + if e.errno != errno.EBADF: + raise self._socket.close() del self._socket - - # These are set in the TcpTransport initializer - _address = None - _timeout = None diff --git a/riak/transports/tcp/transport.py b/riak/transports/tcp/transport.py index 58420767..7f440d7c 100644 --- a/riak/transports/tcp/transport.py +++ b/riak/transports/tcp/transport.py @@ -35,7 +35,12 @@ def __init__(self, self._socket = None self._pbuf_c = None self._ttb_c = None - self._use_ttb = kwargs.get('use_ttb', True) + self._socket_tcp_options = \ + kwargs.get('socket_tcp_options', {}) + self._socket_keepalive = \ + kwargs.get('socket_keepalive', False) + self._use_ttb = \ + kwargs.get('use_ttb', True) def _get_pbuf_codec(self): if not self._pbuf_c: