Skip to content
Open

WIP #193

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pynuodb/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
_fetch.c
936 changes: 936 additions & 0 deletions pynuodb/_fetch.pyx

Large diffs are not rendered by default.

68 changes: 36 additions & 32 deletions pynuodb/crypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,50 +142,54 @@ def fromHex(hexStr):

def toSignedByteString(value):
# type: (int) -> bytearray
"""Convert an integer into bytes."""
result = bytearray()
if value == 0 or value == -1:
result.append(value & 0xFF)
"""Convert an integer into the minimal big-endian two's-complement bytes.

Uses int.to_bytes (C-level) instead of a Python byte-shift loop. For
negative values that are exactly -2**(8*k - 1) (e.g. -128, -32768) the
positive-value formula (bit_length+8)//8 would over-allocate by one byte;
the (-v-1).bit_length() expression below computes the minimal width
without that case split.
"""
if value == 0:
return bytearray(b'\x00')
if value == -1:
return bytearray(b'\xff')
if value > 0:
nbytes = (value.bit_length() + 8) // 8
else:
while value != 0 and value != -1:
result.append(value & 0xFF)
value >>= 8
# Zero pad if positive
if value == 0 and (result[-1] & 0x80) == 0x80:
result.append(0x00)
elif value == -1 and (result[-1] & 0x80) == 0x00:
result.append(0xFF)
result.reverse()
return result
nbytes = ((-value - 1).bit_length() + 8) // 8
return bytearray(value.to_bytes(nbytes, 'big', signed=True))


def fromSignedByteString(data):
# type: (bytearray) -> int
"""Convert bytes into a signed integer."""
if data:
is_neg = (data[0] & 0x80) >> 7
else:
is_neg = 0
result = 0
shiftCount = 0
for b in reversed(data):
result = result | (((b & 0xFF) ^ (is_neg * 0xFF)) << shiftCount)
shiftCount += 8

return ((-1)**is_neg) * (result + is_neg)
return int.from_bytes(data, 'big', signed=True)


def toByteString(bigInt):
# type: (int) -> bytearray
"""Convert an integer into bytes."""
"""Convert a non-negative integer into the minimal big-endian bytes.

The legacy implementation also accepted negative inputs and produced a
quirky truncated two's-complement representation, but no current caller
invokes it that way (lengths, scales, message IDs, SRP primes are all
non-negative). We preserve the original behaviour for 0 and -1 (one byte
of 0x00 / 0xff) and fall back to the old algorithm for any other negative
value so external semantics stay byte-identical.
"""
if bigInt == 0:
return bytearray(b'\x00')
if bigInt == -1:
return bytearray(b'\xff')
if bigInt > 0:
nbytes = (bigInt.bit_length() + 7) // 8
return bytearray(bigInt.to_bytes(nbytes, 'big'))
# Negative-other-than-(-1) fallback (unused in practice).
result = bytearray()
if bigInt == -1 or bigInt == 0:
while bigInt != 0 and bigInt != -1:
result.append(bigInt & 0xFF)
else:
while bigInt != 0 and bigInt != -1:
result.append(bigInt & 0xFF)
bigInt >>= 8
result.reverse()
bigInt >>= 8
result.reverse()
return result


Expand Down
24 changes: 18 additions & 6 deletions pynuodb/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,11 @@ def executemany(self, operation, seq_of_parameters):
def fetchone(self):
# type: () -> Optional[result_set.Row]
"""Return the next row of results from the previous SQL operation."""
self._check_closed()
# Inline _check_closed to avoid per-row function-call overhead.
if self.closed:
raise Error("cursor is closed")
if self.session.closed:
raise Error("connection is closed")
if self._result_set is None:
raise Error("Previous execute did not produce any results or no call was issued yet")
self.rownumber += 1
Expand Down Expand Up @@ -218,14 +222,22 @@ def fetchall(self):
# type: () -> List[result_set.Row]
"""Return all rows generated by the previous SQL operation."""
self._check_closed()
if self._result_set is None:
raise Error("Previous execute did not produce any results or no call was issued yet")

fetched_rows = []
fetched_rows = [] # type: List[result_set.Row]
while True:
row = self.fetchone()
if row is None:
# Drain the current in-memory batch in one shot instead of calling
# fetchone() per row.
idx = self._result_set.results_idx
batch = self._result_set.results
if idx < len(batch):
fetched_rows.extend(batch[idx:])
self._result_set.results_idx = len(batch)
if self._result_set.complete:
break
else:
fetched_rows.append(row)
self.session.fetch_result_set_next(self._result_set)
self.rownumber += len(fetched_rows)
return fetched_rows

def nextset(self): # pylint: disable=no-self-use
Expand Down
112 changes: 86 additions & 26 deletions pynuodb/encodedsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@
from . import result_set
from .datatype import LOCALZONE_NAME

try:
from . import _fetch as _fetch_accel
_HAVE_FETCH_ACCEL = True
except ImportError:
_HAVE_FETCH_ACCEL = False

# ZoneInfo is preferred but not introduced until 3.9
if sys.version_info >= (3, 9):
# preferred python >= 3.9
Expand Down Expand Up @@ -465,16 +471,23 @@ def execute_batch_prepared_statement(self, prepared_statement, param_lists):
"""Batch the prepared statement with the given parameters."""
self._setup_statement(prepared_statement.handle, protocol.EXECUTEBATCHPREPAREDSTATEMENT)

for parameters in param_lists:
plen = len(parameters)
if prepared_statement.parameter_count != plen:
raise ProgrammingError("Incorrect number of parameters specified,"
" expected %d, got %d"
% (prepared_statement.parameter_count,
plen))
self.putInt(plen)
for param in parameters:
self.putValue(param)
expected = prepared_statement.parameter_count
if _HAVE_FETCH_ACCEL:
# encode_batch needs a real list: it indexes into it and calls len() on it.
if not isinstance(param_lists, list):
param_lists = list(param_lists)
_fetch_accel.encode_batch(
self.__output, param_lists, expected, self)
else:
for parameters in param_lists:
plen = len(parameters)
if expected != plen:
raise ProgrammingError("Incorrect number of parameters specified,"
" expected %d, got %d"
% (expected, plen))
self.putInt(plen)
for param in parameters:
self.putValue(param)
self.putInt(-1)
self.putInt(len(param_lists))
self._exchangeMessages()
Expand Down Expand Up @@ -512,25 +525,46 @@ def fetch_result_set(self, stmt):
for _ in range(colcount):
self.getString()

complete = False
init_results = [] # type: List[result_set.Row]

# If we hit the end of the stream without next==0, there are more
# results to fetch.
while self._hasBytes(1):
next_row = self.getInt()
if next_row == 0:
complete = True
break
if _HAVE_FETCH_ACCEL:
# The initial row batch has the same wire format as subsequent
# NEXT batches; reuse the Cython decoder rather than calling
# getValue() per cell through the Python path.
pos, complete = _fetch_accel.decode_next_batch(
self.__input, self.__inpos, colcount,
init_results, self._cython_exotic_decode,
self.timezone_info)
self.__inpos = pos
else:
complete = False
# If we hit the end of the stream without next==0, there are more
# results to fetch.
while self._hasBytes(1):
next_row = self.getInt()
if next_row == 0:
complete = True
break

row = [None] * colcount
for i in range(colcount):
row[i] = self.getValue()
row = [None] * colcount
for i in range(colcount):
row[i] = self.getValue()

init_results.append(tuple(row))
init_results.append(tuple(row))

return result_set.ResultSet(handle, colcount, init_results, complete)

def _cython_exotic_decode(self, pos):
# type: (int) -> tuple
"""Bridge called by _fetch_accel.decode_next_batch for exotic wire types.

Sets __inpos to pos, calls getValue() (which handles any NuoDB type),
then returns (value, new_pos) so the Cython loop can resume.
"""
self.__inpos = pos
val = self.getValue()
return val, self.__inpos

def fetch_result_set_next(self, resultset):
# type: (result_set.ResultSet) -> None
"""Get more rows from this result set."""
Expand All @@ -539,6 +573,15 @@ def fetch_result_set_next(self, resultset):

resultset.clear_results()

if _HAVE_FETCH_ACCEL:
pos, complete = _fetch_accel.decode_next_batch(
self.__input, self.__inpos, resultset.col_count,
resultset.results, self._cython_exotic_decode,
self.timezone_info)
self.__inpos = pos
resultset.complete = complete
return

while self._hasBytes(1):
if self.getInt() == 0:
resultset.complete = True
Expand Down Expand Up @@ -869,6 +912,24 @@ def putValue(self, value): # pylint: disable=too-many-return-statements
if value is None:
return self.putNull()

# Fast paths: `type(v) is X` is a C-level pointer compare; isinstance
# walks the MRO and is markedly slower. These hit on the bulk of
# bound parameters (plain int / str / float).
tv = type(value)
if tv is int:
return self.putInt(value)
if tv is str:
return self.putString(value)
if tv is float:
return self.putDouble(value)
if tv is bool:
# Preserve historic wire behaviour: bools encode as integers
# because the original isinstance(value, int) check matched True
# and False before the later isinstance(value, bool) branch could
# run; that branch was dead code and is removed by this change.
return self.putInt(value)

# Subclass-aware fallback for the long tail (int/str subclasses,
# Decimal, datetime types, Binary, Vector, etc.).
if isinstance(value, int):
return self.putInt(value)

Expand All @@ -891,9 +952,6 @@ def putValue(self, value): # pylint: disable=too-many-return-statements
if isinstance(value, datatype.Binary):
return self.putOpaque(value)

if isinstance(value, bool):
return self.putBoolean(value)

# we don't want to autodetect lists as being VECTOR, so we
# only bind double if it is the explicit type
if isinstance(value, datatype.Vector):
Expand All @@ -919,6 +977,7 @@ def getInt(self):

raise DataError('Not an integer: %d' % (code))


# Does not preserve E notation
def getScaledInt(self):
# type: () -> decimal.Decimal
Expand Down Expand Up @@ -1311,7 +1370,8 @@ def _exchangeMessages(self, getResponse=True):
resp = self.recv(timeout=None)
if resp is None:
db_error_handler(protocol.OPERATION_TIMEOUT, "timed out")
self.__input = crypt.bytesToArray(resp)
# recv() now returns bytearray directly; no copy needed.
self.__input = resp

error = self.getInt()
if error != 0:
Expand Down
9 changes: 9 additions & 0 deletions pynuodb/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,12 @@ def fetchone(self):
res = self.results[self.results_idx]
self.results_idx += 1
return res


# Replace the Python implementation above with the Cython cdef class when the
# extension has been built. The interface is identical; fetchone() and
# is_complete() become near-C-speed cpdef calls.
try:
from ._fetch import ResultSet # noqa: F811 pylint: disable=unused-import
except ImportError:
pass
Loading