diff --git a/dbus_next/_private/unmarshaller.py b/dbus_next/_private/unmarshaller.py
index 2bf0b2a..0b98a7e 100644
--- a/dbus_next/_private/unmarshaller.py
+++ b/dbus_next/_private/unmarshaller.py
@@ -4,7 +4,8 @@
 from ..signature import SignatureTree, Variant
 from ..errors import InvalidMessageError
 
-from struct import unpack
+from codecs import decode
+from struct import unpack_from
 
 
 class MarshallerStreamEndError(Exception):
@@ -30,7 +31,7 @@ def __init__(self, stream):
             't': self.read_uint64,
             'd': self.read_double,
             'h': self.read_uint32,
-            'o': self.read_object_path,
+            'o': self.read_string,
             's': self.read_string,
             'g': self.read_signature,
             'a': self.read_array,
@@ -39,35 +40,45 @@ def __init__(self, stream):
             'v': self.read_variant
         }
 
-    def read(self, n):
+    def read(self, n, prefetch=False):
+        """
+        Read from underlying socket into buffer and advance offset accordingly.
+
+        :arg n:
+            Number of bytes to read. If not enough bytes are available in the
+            buffer, read the missing bytes from the stream.
+        :arg prefetch:
+            Do not update current offset after reading.
+
+        :returns:
+            Previous offset (before reading). To get the actual read bytes,
+            use the returned value and self.buf.
+        """
         # store previously read data in a buffer so we can resume on socket
         # interruptions
-        data = bytearray()
-        if self.offset < len(self.buf):
-            data = self.buf[self.offset:self.offset + n]
-            self.offset += len(data)
-            n -= len(data)
-        if n:
-            read = self.stream.read(n)
-            if read == b'':
+        missing_bytes = n - (len(self.buf) - self.offset)
+        if missing_bytes > 0:
+            data = self.stream.read(missing_bytes)
+            if data == b'':
                 raise EOFError()
-            elif read is None:
+            elif data is None:
                 raise MarshallerStreamEndError()
-            data.extend(read)
-            self.buf.extend(read)
-            if len(read) != n:
+            self.buf.extend(data)
+            if len(data) != missing_bytes:
                 raise MarshallerStreamEndError()
-        self.offset += n
-        return bytes(data)
+        prev = self.offset
+        if not prefetch:
+            self.offset += n
+        return prev
 
     def align(self, n):
         padding = n - self.offset % n
         if padding == 0 or padding == n:
-            return b''
-        return self.read(padding)
+            return
+        self.read(padding)
 
     def read_byte(self, _=None):
-        return self.read(1)[0]
+        return self.buf[self.read(1)]
 
     def read_boolean(self, _=None):
         data = self.read_uint32()
@@ -77,64 +88,48 @@ def read_boolean(self, _=None):
             return False
 
     def read_int16(self, _=None):
-        self.align(2)
-        fmt = 'h'
-        data = self.read(2)
-        return unpack(fmt, data)[0]
+        return self.read_ctype('h', 2)
 
     def read_uint16(self, _=None):
-        self.align(2)
-        fmt = 'H'
-        data = self.read(2)
-        return unpack(fmt, data)[0]
+        return self.read_ctype('H', 2)
 
     def read_int32(self, _=None):
-        self.align(4)
-        fmt = 'i'
-        data = self.read(4)
-        return unpack(fmt, data)[0]
+        return self.read_ctype('i', 4)
 
     def read_uint32(self, _=None):
-        self.align(4)
-        fmt = 'I'
-        data = self.read(4)
-        return unpack(fmt, data)[0]
+        return self.read_ctype('I', 4)
 
     def read_int64(self, _=None):
-        self.align(8)
-        fmt = 'q'
-        data = self.read(8)
-        return unpack(fmt, data)[0]
+        return self.read_ctype('q', 8)
 
     def read_uint64(self, _=None):
-        self.align(8)
-        fmt = 'Q'
-        data = self.read(8)
-        return unpack(fmt, data)[0]
+        return self.read_ctype('Q', 8)
 
     def read_double(self, _=None):
-        self.align(8)
-        fmt = 'd'
-        data = self.read(8)
-        return unpack(fmt, data)[0]
+        return self.read_ctype('d', 8)
 
-    def read_object_path(self, _=None):
-        path_length = self.read_uint32()
-        data = self.read(path_length)
-        self.read(1)
-        return data.decode()
+    def read_ctype(self, fmt, size):
+        self.align(size)
+        if self.endian == LITTLE_ENDIAN:
+            fmt = '<' + fmt
+        else:
+            fmt = '>' + fmt
+        o = self.read(size)
+        return unpack_from(fmt, self.buf, o)[0]
 
     def read_string(self, _=None):
         str_length = self.read_uint32()
-        data = self.read(str_length)
-        self.read(1)
-        return data.decode()
+        o = self.read(str_length + 1)  # read terminating '\0' byte as well
+        # avoid buffer copies when slicing
+        str_mem_slice = memoryview(self.buf)[o:o + str_length]
+        return decode(str_mem_slice)
 
     def read_signature(self, _=None):
         signature_len = self.read_byte()
-        data = self.read(signature_len)
-        self.read(1)
-        return data.decode()
+        o = self.read(signature_len + 1)  # read terminating '\0' byte as well
+        # avoid buffer copies when slicing
+        sig_mem_slice = memoryview(self.buf)[o:o + signature_len]
+        return decode(sig_mem_slice)
 
     def read_variant(self, _=None):
         signature = self.read_signature()
@@ -177,7 +172,10 @@ def read_array(self, type_):
                 key, value = self.read_dict_entry(child_type)
                 result[key] = value
         elif child_type.token == 'y':
-            result = self.read(array_length)
+            o = self.read(array_length)
+            # avoid buffer copies when slicing
+            array_mem_slice = memoryview(self.buf)[o:o + array_length]
+            result = array_mem_slice.tobytes()
         else:
             result = []
             while self.offset - beginning_offset < array_length:
@@ -195,6 +193,7 @@ def read_argument(self, type_):
 
     def _unmarshall(self):
         self.offset = 0
+        self.read(16, prefetch=True)
         self.endian = self.read_byte()
         if self.endian != LITTLE_ENDIAN and self.endian != BIG_ENDIAN:
             raise InvalidMessageError('Expecting endianness as the first byte')
@@ -209,6 +208,14 @@ def _unmarshall(self):
         body_len = self.read_uint32()
         serial = self.read_uint32()
 
+        header_len = self.read_uint32()
+        # the header fields array is padded to an 8-byte boundary before the
+        # body starts, so include that padding in the prefetched length
+        header_padding = (-header_len) % 8
+        self.read(header_len + header_padding + body_len, prefetch=True)
+        # backtrack offset since header array length needs to be read again
+        self.offset -= 4
+
         header_fields = {HeaderField.UNIX_FDS.name: []}
         for field_struct in self.read_argument(SignatureTree('a(yv)').types[0]):
             field = HeaderField(field_struct[0])