Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 64 additions & 60 deletions dbus_next/_private/unmarshaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from ..signature import SignatureTree, Variant
from ..errors import InvalidMessageError

from struct import unpack
from codecs import decode
from struct import unpack_from


class MarshallerStreamEndError(Exception):
Expand All @@ -30,7 +31,7 @@ def __init__(self, stream):
't': self.read_uint64,
'd': self.read_double,
'h': self.read_uint32,
'o': self.read_object_path,
'o': self.read_string,
's': self.read_string,
'g': self.read_signature,
'a': self.read_array,
Expand All @@ -39,35 +40,45 @@ def __init__(self, stream):
'v': self.read_variant
}

def read(self, n, prefetch=False):
    """
    Make sure ``n`` bytes are buffered starting at the current offset.

    Bytes already held in ``self.buf`` past ``self.offset`` are reused;
    only the shortfall is requested from ``self.stream`` and appended to
    ``self.buf``.

    :arg n:
        Number of bytes that must be available at the current offset.
    :arg prefetch:
        If true, fill the buffer but do not advance ``self.offset``.

    :returns:
        The offset before reading. The bytes themselves are
        ``self.buf[offset:offset + n]``.

    :raises EOFError:
        If the stream returns ``b''``.
    :raises MarshallerStreamEndError:
        If the stream returns ``None`` or fewer bytes than requested.
    """
    missing_bytes = n - (len(self.buf) - self.offset)
    if missing_bytes > 0:
        data = self.stream.read(missing_bytes)
        if data == b'':
            raise EOFError()
        elif data is None:
            raise MarshallerStreamEndError()
        # Keep a short read in the buffer before signalling it, so a later
        # call only has to request the remainder.
        self.buf.extend(data)
        if len(data) != missing_bytes:
            raise MarshallerStreamEndError()
    prev = self.offset
    if not prefetch:
        self.offset += n
    return prev

def align(self, n):
    """
    Skip padding bytes so that ``self.offset`` is a multiple of ``n``.

    :arg n:
        Required alignment in bytes.

    Nothing is read when the offset is already aligned.
    """
    # Number of bytes up to the next multiple of n; 0 when already aligned.
    padding = -self.offset % n
    if padding:
        self.read(padding)

def read_byte(self, _=None):
    """
    Read a single byte and return it as an int.

    :arg _:
        Unused.
    """
    # read() returns the offset of the byte inside self.buf, not the data
    return self.buf[self.read(1)]

def read_boolean(self, _=None):
data = self.read_uint32()
Expand All @@ -77,64 +88,48 @@ def read_boolean(self, _=None):
return False

def read_int16(self, _=None):
    """
    Read an aligned signed 16-bit integer in the message's byte order.

    :arg _:
        Unused.
    """
    return self.read_ctype('h', 2)

def read_uint16(self, _=None):
    """
    Read an aligned unsigned 16-bit integer in the message's byte order.

    :arg _:
        Unused.
    """
    return self.read_ctype('H', 2)

def read_int32(self, _=None):
    """
    Read an aligned signed 32-bit integer in the message's byte order.

    :arg _:
        Unused.
    """
    return self.read_ctype('i', 4)

def read_uint32(self, _=None):
    """
    Read an aligned unsigned 32-bit integer in the message's byte order.

    :arg _:
        Unused.
    """
    return self.read_ctype('I', 4)

def read_int64(self, _=None):
    """
    Read an aligned signed 64-bit integer in the message's byte order.

    :arg _:
        Unused.
    """
    return self.read_ctype('q', 8)

def read_uint64(self, _=None):
    """
    Read an aligned unsigned 64-bit integer in the message's byte order.

    :arg _:
        Unused.
    """
    return self.read_ctype('Q', 8)

def read_double(self, _=None):
    """
    Read an aligned 8-byte floating point number in the message's byte
    order.

    :arg _:
        Unused.
    """
    return self.read_ctype('d', 8)

def read_object_path(self, _=None):
    """
    Read an object path and return it as ``str``.

    Kept for callers that still use this name. The body it replaces was
    identical to the string reader's (uint32 length, payload, one
    terminating byte), so it simply delegates to ``read_string``.

    :arg _:
        Unused.
    """
    return self.read_string()
def read_ctype(self, fmt, size):
    """
    Read one fixed-size value described by a ``struct`` format character.

    The offset is first aligned to ``size``, then ``size`` bytes are
    consumed and decoded directly from ``self.buf``.

    :arg fmt:
        ``struct`` format character without a byte-order prefix.
    :arg size:
        Size of the value in bytes; also used as its alignment.

    :returns:
        The first value unpacked from the buffer.
    """
    self.align(size)
    # pick the struct byte-order prefix matching the message's endianness
    byte_order = '<' if self.endian == LITTLE_ENDIAN else '>'
    start = self.read(size)
    values = unpack_from(byte_order + fmt, self.buf, start)
    return values[0]

def read_string(self, _=None):
    """
    Read a length-prefixed string and return it as ``str``.

    The length is a uint32; the payload is followed by one terminating
    byte that is consumed but not included in the result.

    :arg _:
        Unused.
    """
    str_length = self.read_uint32()
    # consume the terminating '\0' byte together with the payload
    start = self.read(str_length + 1)
    # slice through a memoryview to avoid copying the buffer
    str_mem_slice = memoryview(self.buf)[start:start + str_length]
    return decode(str_mem_slice)

def read_signature(self, _=None):
    """
    Read a signature and return it as ``str``.

    The length is a single byte; the payload is followed by one
    terminating byte that is consumed but not included in the result.

    :arg _:
        Unused.
    """
    signature_len = self.read_byte()
    # consume the terminating '\0' byte together with the payload
    start = self.read(signature_len + 1)
    # slice through a memoryview to avoid copying the buffer
    sig_mem_slice = memoryview(self.buf)[start:start + signature_len]
    return decode(sig_mem_slice)

def read_variant(self, _=None):
signature = self.read_signature()
Expand Down Expand Up @@ -177,7 +172,10 @@ def read_array(self, type_):
key, value = self.read_dict_entry(child_type)
result[key] = value
elif child_type.token == 'y':
result = self.read(array_length)
o = self.read(array_length)
# avoid buffer copies when slicing
array_mem_slice = memoryview(self.buf)[o:o + array_length]
result = array_mem_slice.tobytes()
else:
result = []
while self.offset - beginning_offset < array_length:
Expand All @@ -195,6 +193,7 @@ def read_argument(self, type_):

def _unmarshall(self):
self.offset = 0
self.read(16, prefetch=True)
self.endian = self.read_byte()
if self.endian != LITTLE_ENDIAN and self.endian != BIG_ENDIAN:
raise InvalidMessageError('Expecting endianness as the first byte')
Expand All @@ -209,6 +208,11 @@ def _unmarshall(self):
body_len = self.read_uint32()
serial = self.read_uint32()

header_len = self.read_uint32()
self.read(header_len + body_len, prefetch=True)
# backtrack offset since header array length needs to be read again
self.offset -= 4

header_fields = {HeaderField.UNIX_FDS.name: []}
for field_struct in self.read_argument(SignatureTree('a(yv)').types[0]):
field = HeaderField(field_struct[0])
Expand Down