Skip to content
4 changes: 2 additions & 2 deletions riak/client/transport.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import contextmanager
from riak.transports.pool import BadResource
from riak.transports.tcp import is_retryable as is_pbc_retryable
from riak.transports.tcp import is_retryable as is_tcp_retryable
from riak.transports.http import is_retryable as is_http_retryable
import threading
from six import PY2
Expand Down Expand Up @@ -162,7 +162,7 @@ def _is_retryable(error):
:type error: Exception
:rtype: boolean
"""
return is_pbc_retryable(error) or is_http_retryable(error)
return is_tcp_retryable(error) or is_http_retryable(error)


def retryable(fn, protocol=None):
Expand Down
10 changes: 7 additions & 3 deletions riak/riak_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ class RiakError(Exception):
"""
Base class for exceptions generated in the Riak API.
"""
def __init__(self, value):
self.value = value
def __init__(self, *args, **kwargs):
super(RiakError, self).__init__(*args, **kwargs)
if len(args) > 0:
self.value = args[0]
else:
self.value = 'unknown'

def __str__(self):
return repr(self.value)
Expand All @@ -34,5 +38,5 @@ class ConflictError(RiakError):
:class:`~riak.riak_object.RiakObject` that has more than one
sibling.
"""
def __init__(self, message="Object in conflict"):
def __init__(self, message='Object in conflict'):
super(ConflictError, self).__init__(message)
19 changes: 12 additions & 7 deletions riak/tests/test_btypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,18 @@ def test_multiget_bucket_types(self):
self.assertEqual(btype, mobj.bucket.bucket_type)

def test_write_once_bucket_type(self):
btype = self.client.bucket_type('write_once')
bucket = btype.bucket(self.bucket_name)

for i in range(100):
obj = bucket.new(self.key_name + str(i))
obj.data = {'id': i}
obj.store()
bt = 'write_once'
skey = 'write_once-init'
btype = self.client.bucket_type(bt)
bucket = btype.bucket(bt)
sobj = bucket.get(skey)
if not sobj.exists:
for i in range(100):
o = bucket.new(self.key_name + str(i))
o.data = {'id': i}
o.store()
o = bucket.new(skey, data={'id': skey})
o.store()

mget = bucket.multiget([self.key_name + str(i) for i in range(100)])
for mobj in mget:
Expand Down
12 changes: 12 additions & 0 deletions riak/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from six import PY2
from threading import Thread
from riak.riak_object import RiakObject
from riak.transports.tcp import TcpTransport
from riak.tests import DUMMY_HTTP_PORT, DUMMY_PB_PORT, RUN_POOL
from riak.tests.base import IntegrationTestBase

Expand All @@ -13,6 +14,17 @@


class ClientTests(IntegrationTestBase, unittest.TestCase):
def test_can_set_tcp_keepalive(self):
if self.protocol == 'pbc':
topts = {'socket_keepalive': True}
c = self.create_client(transport_options=topts)
for i, r in enumerate(c._tcp_pool.resources):
self.assertIsInstance(r, TcpTransport)
self.assertTrue(r._socket_keepalive)
c.close()
else:
pass

def test_uses_client_id_if_given(self):
if self.protocol == 'pbc':
zero_client_id = "\0\0\0\0"
Expand Down
25 changes: 18 additions & 7 deletions riak/tests/test_kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,17 +180,29 @@ def test_string_bucket_name(self):
def test_generate_key(self):
# Ensure that Riak generates a random key when
# the key passed to bucket.new() is None.
bucket = self.client.bucket('random_key_bucket')
existing_keys = bucket.get_keys()
bucket = self.client.bucket(self.bucket_name)
o = bucket.new(None, data={})
self.assertIsNone(o.key)
o.store()
self.assertIsNotNone(o.key)
self.assertNotIn('/', o.key)
self.assertNotIn(o.key, existing_keys)
self.assertEqual(len(bucket.get_keys()), len(existing_keys) + 1)
existing_keys = bucket.get_keys()
self.assertEqual(len(existing_keys), 1)

def maybe_store_keys(self):
skey = 'rkb-init'
bucket = self.client.bucket('random_key_bucket')
sobj = bucket.get(skey)
if sobj.exists:
return
for key in range(1, 1000):
o = bucket.new(None, data={})
o.store()
o = bucket.new(skey, data={})
o.store()

def test_stream_keys(self):
self.maybe_store_keys()
bucket = self.client.bucket('random_key_bucket')
regular_keys = bucket.get_keys()
self.assertNotEqual(len(regular_keys), 0)
Expand All @@ -203,10 +215,8 @@ def test_stream_keys(self):
self.assertEqual(sorted(regular_keys), sorted(streamed_keys))

def test_stream_keys_timeout(self):
self.maybe_store_keys()
bucket = self.client.bucket('random_key_bucket')
for key in range(1, 1000):
o = bucket.new(None, data={})
o.store()
streamed_keys = []
with self.assertRaises(RiakError):
for keylist in self.client.stream_keys(bucket, timeout=1):
Expand All @@ -216,6 +226,7 @@ def test_stream_keys_timeout(self):
streamed_keys += keylist

def test_stream_keys_abort(self):
self.maybe_store_keys()
bucket = self.client.bucket('random_key_bucket')
regular_keys = bucket.get_keys()
self.assertNotEqual(len(regular_keys), 0)
Expand Down
2 changes: 1 addition & 1 deletion riak/transports/tcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def destroy_resource(self, tcp):
def is_retryable(err):
"""
Determines if the given exception is something that is
network/socket-related and should thus cause the PBC connection to
network/socket-related and should thus cause the TCP connection to
close and the operation retried on another node.

:rtype: boolean
Expand Down
41 changes: 33 additions & 8 deletions riak/transports/tcp/connection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import errno
import socket
import struct

Expand All @@ -7,16 +8,23 @@
from riak import RiakError
from riak.codecs.pbuf import PbufCodec
from riak.security import SecurityError, USE_STDLIB_SSL
from riak.transports.pool import BadResource

if not USE_STDLIB_SSL:
from OpenSSL.SSL import Connection
from riak.transports.security import configure_pyopenssl_context
else:
if USE_STDLIB_SSL:
import ssl
from riak.transports.security import configure_ssl_context
else:
from OpenSSL.SSL import Connection
from riak.transports.security import configure_pyopenssl_context


class TcpConnection(object):
# These are set in the TcpTransport initializer
_address = None
_timeout = None
_socket_keepalive = None
_socket_tcp_options = None

"""
Connection-related methods for TcpTransport.
"""
Expand Down Expand Up @@ -174,6 +182,10 @@ def _recv(self, msglen):
toread = msglen
while toread:
nbytes = self._socket.recv_into(view, toread)
# https://docs.python.org/2/howto/sockets.html#using-a-socket
# https://github.com/basho/riak-python-client/issues/399
if nbytes == 0:
raise BadResource('recv_into returned zero bytes unexpectedly')
view = view[nbytes:] # slicing views is cheap
toread -= nbytes
nread += nbytes
Expand All @@ -189,6 +201,13 @@ def _connect(self):
self._timeout)
else:
self._socket = socket.create_connection(self._address)
if self._socket_tcp_options:
ka_opts = self._socket_tcp_options
for k, v in ka_opts.iteritems():
self._socket.setsockopt(socket.SOL_TCP, k, v)
if self._socket_keepalive:
self._socket.setsockopt(
socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if self._client._credentials:
self._init_security()

Expand All @@ -197,9 +216,15 @@ def close(self):
Closes the underlying socket of the PB connection.
"""
if self._socket:
if USE_STDLIB_SSL:
# NB: Python 2.7.8 and earlier does not have a compatible
# shutdown() method due to the SSL lib
try:
self._socket.shutdown(socket.SHUT_RDWR)
except IOError as e:
# NB: sometimes this is the exception if the initial
# connection didn't succeed correctly
if e.errno != errno.EBADF:
raise
self._socket.close()
del self._socket

# These are set in the TcpTransport initializer
_address = None
_timeout = None
7 changes: 6 additions & 1 deletion riak/transports/tcp/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ def __init__(self,
self._socket = None
self._pbuf_c = None
self._ttb_c = None
self._use_ttb = kwargs.get('use_ttb', True)
self._socket_tcp_options = \
kwargs.get('socket_tcp_options', {})
self._socket_keepalive = \
kwargs.get('socket_keepalive', False)
self._use_ttb = \
kwargs.get('use_ttb', True)

def _get_pbuf_codec(self):
if not self._pbuf_c:
Expand Down