Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dbus_next/_private/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum

PROTOCOL_VERSION = 1
MESSAGE_HEADER_LEN = 16

LITTLE_ENDIAN = ord('l')
BIG_ENDIAN = ord('B')
Expand Down
45 changes: 36 additions & 9 deletions dbus_next/_private/unmarshaller.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from ..message import Message
from .constants import HeaderField, LITTLE_ENDIAN, BIG_ENDIAN, PROTOCOL_VERSION
from .constants import HeaderField, LITTLE_ENDIAN, BIG_ENDIAN, PROTOCOL_VERSION, MESSAGE_HEADER_LEN
from ..constants import MessageType, MessageFlag
from ..signature import SignatureTree, Variant
from ..errors import InvalidMessageError

import array
import socket
from struct import unpack


Expand All @@ -12,10 +14,12 @@ class MarshallerStreamEndError(Exception):


class Unmarshaller:
def __init__(self, stream):
def __init__(self, stream, sock=None):
self.unix_fds = []
self.buf = bytearray()
self.offset = 0
self.stream = stream
self.sock = sock
self.endian = None
self.message = None

Expand Down Expand Up @@ -66,6 +70,25 @@ def align(self, n):
return b''
return self.read(padding)

def read_endian(self, _=None):
unix_fds = array.array("i")

if self.sock:
try:
Copy link
Copy Markdown
Member

@acrisci acrisci Jul 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code path might run multiple times for a single unmarshalling of a message. That will happen in the case there was a blocking error. The previously read data will be read from a buffer cache instead of the socket.

One solution would be to cache the fds on the unmarshaller instance and then skip reading the fds from the socket in the case that the fd cache is present.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added that cache to unmarshaller, please have a look if this is what you had in mind.

msg, ancdata, *_ = self.sock.recvmsg(MESSAGE_HEADER_LEN,
socket.CMSG_LEN(16 * unix_fds.itemsize),
socket.MSG_PEEK)
except BlockingIOError:
return self.read_byte(), list(unix_fds)

for level, type, data in ancdata:
if not (level == socket.SOL_SOCKET and type == socket.SCM_RIGHTS):
continue

unix_fds.frombytes(data[:len(data) - (len(data) % unix_fds.itemsize)])

return self.read_byte(), list(unix_fds)

def read_byte(self, _=None):
return self.read(1)[0]

Expand Down Expand Up @@ -195,7 +218,10 @@ def read_argument(self, type_):

def _unmarshall(self):
self.offset = 0
self.endian = self.read_byte()
self.endian, unix_fds = self.read_endian()
if not self.unix_fds:
self.unix_fds = unix_fds

if self.endian != LITTLE_ENDIAN and self.endian != BIG_ENDIAN:
raise InvalidMessageError('Expecting endianness as the first byte')
message_type = MessageType(self.read_byte())
Expand All @@ -209,13 +235,13 @@ def _unmarshall(self):
body_len = self.read_uint32()
serial = self.read_uint32()

header_fields = {HeaderField.UNIX_FDS.name: []}
header_fields = {}
for field_struct in self.read_argument(SignatureTree('a(yv)').types[0]):
field = HeaderField(field_struct[0])
if field == HeaderField.UNIX_FDS:
header_fields[field.name].append(field_struct[1].value)
else:
header_fields[field.name] = field_struct[1].value
continue

header_fields[field.name] = field_struct[1].value

self.align(8)

Expand All @@ -228,7 +254,8 @@ def _unmarshall(self):
sender = header_fields.get(HeaderField.SENDER.name)
signature = header_fields.get(HeaderField.SIGNATURE.name, '')
signature_tree = SignatureTree(signature)
unix_fds = header_fields.get(HeaderField.UNIX_FDS.name)
# TODO: check unix_fds against the header
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what to check here

# unix_fds = header_fields.get(HeaderField.UNIX_FDS.name)

body = []

Expand All @@ -245,7 +272,7 @@ def _unmarshall(self):
error_name=error_name,
reply_serial=reply_serial,
sender=sender,
unix_fds=unix_fds,
unix_fds=self.unix_fds,
signature=signature_tree,
body=body,
serial=serial)
Expand Down
51 changes: 45 additions & 6 deletions dbus_next/aio/message_bus.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ..message_bus import BaseMessageBus
from .._private.unmarshaller import Unmarshaller
from .._private.constants import MESSAGE_HEADER_LEN
from ..message import Message
from ..constants import BusType, NameFlag, RequestNameReply, ReleaseNameReply, MessageType, MessageFlag
from ..service import ServiceInterface
Expand All @@ -8,8 +9,10 @@
from .. import introspection as intr
from ..auth import Authenticator, AuthExternal

import array
import logging
import asyncio
import socket
import traceback
from typing import Optional

Expand Down Expand Up @@ -43,9 +46,9 @@ def __init__(self,
auth: Authenticator = None):
super().__init__(bus_address, bus_type, ProxyObject)
self._loop = asyncio.get_event_loop()
self._unmarshaller = Unmarshaller(self._stream)
self._unmarshaller = Unmarshaller(self._stream, self._sock)
if auth is None:
self._auth = AuthExternal()
self._auth = AuthExternal(enable_fds=True)
else:
self._auth = auth

Expand Down Expand Up @@ -222,6 +225,32 @@ def reply_handler(reply, err):

return future.result()

def _sock_sendmsg(self, sock, *buffers, ancdata=None, flags=0):
fd = sock.fileno()
fut = asyncio.futures.Future(loop=self._loop)

def __sock_sendmsg(registered=False):
if registered:
self._loop.remove_writer(fd)

if fut.cancelled():
return

try:
size = sock.sendmsg(buffers, ancdata or [], flags)
except (BlockingIOError, InterruptedError):
self._loop.add_writer(fd, __sock_sendmsg, True)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(size)

if self._loop._debug and sock.gettimeout() != 0:
raise ValueError('Socket %r must be non-blocking' % sock)

__sock_sendmsg()
return fut

def send(self, msg: Message):
if not msg.serial:
msg.serial = self.next_serial()
Expand All @@ -231,7 +260,16 @@ def send(self, msg: Message):
self._buffered_messages.append(msg)
return

asyncio.ensure_future(self._loop.sock_sendall(self._sock, msg._marshall()))
buf = msg._marshall()

async def _send():
ancdata = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", msg.unix_fds))] \
if msg.unix_fds else None

await self._sock_sendmsg(self._sock, buf[:MESSAGE_HEADER_LEN], ancdata=ancdata)
await self._loop.sock_sendall(self._sock, buf[MESSAGE_HEADER_LEN:])

asyncio.ensure_future(_send())

def get_proxy_object(self, bus_name: str, path: str, introspection: intr.Node) -> ProxyObject:
return super().get_proxy_object(bus_name, path, introspection)
Expand All @@ -245,8 +283,9 @@ def handler(msg, send_reply):
def done(fut):
with send_reply:
result = fut.result()
body = ServiceInterface._fn_result_to_body(result, method.out_signature_tree)
send_reply(Message.new_method_return(msg, method.out_signature, body))
body, fds = ServiceInterface._fn_result_to_body(result,
method.out_signature_tree)
send_reply(Message.new_method_return(msg, method.out_signature, body, fds))

fut = asyncio.ensure_future(method.fn(interface, *msg.body))
fut.add_done_callback(done)
Expand All @@ -258,7 +297,7 @@ def _message_reader(self):
while True:
if self._unmarshaller.unmarshall():
self._on_message(self._unmarshaller.message)
self._unmarshaller = Unmarshaller(self._stream)
self._unmarshaller = Unmarshaller(self._stream, self._sock)
else:
break
except Exception as e:
Expand Down
9 changes: 8 additions & 1 deletion dbus_next/aio/proxy_object.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ..proxy_object import BaseProxyObject, BaseProxyInterface
from ..message_bus import BaseMessageBus
from ..message import Message, MessageFlag
from ..message import Message, MessageFlag, _replace_fds
from ..signature import Variant
from ..errors import DBusError
from ..constants import ErrorType
Expand Down Expand Up @@ -82,6 +82,13 @@ async def method_fn(*args, flags=MessageFlag.NONE):
BaseProxyInterface._check_method_return(msg, intr_method.out_signature)

out_len = len(intr_method.out_args)

def _replace(obj):
return msg.unix_fds[obj]

if any(sig in msg.signature for sig in 'hv'):
_replace_fds(msg.body, msg.signature_tree.types, _replace)

if not out_len:
return None
elif out_len == 1:
Expand Down
21 changes: 18 additions & 3 deletions dbus_next/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,32 @@ class AuthExternal(Authenticator):

:sealso: https://dbus.freedesktop.org/doc/dbus-specification.html#auth-protocol
"""
def __init__(self, enable_fds=False):
self.enable_fds = enable_fds
self.negotiating_fds = False

def _authentication_start(self) -> str:
hex_uid = str(os.getuid()).encode().hex()
return f'AUTH EXTERNAL {hex_uid}'

def _receive_line(self, line: str):
response, args = _AuthResponse.parse(line)

if response != _AuthResponse.OK:
raise AuthError(f'authentication failed: {response.value}: {args}')
if response is _AuthResponse.OK:
if self.enable_fds:
self.NEGOTIATING_FDS = True
return "NEGOTIATE_UNIX_FD"
else:
return "BEGIN"

return 'BEGIN'
if response is _AuthResponse.AGREE_UNIX_FD:
return "BEGIN"

if response is _AuthResponse.ERROR and self.NEGOTIATING_FDS:
# TODO: logger?
return "BEGIN"

raise AuthError(f'authentication failed: {response.value}: {args}')


class AuthAnnonymous(Authenticator):
Expand Down
60 changes: 53 additions & 7 deletions dbus_next/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .errors import InvalidMessageError
from .signature import SignatureTree, Variant

from typing import List, Any
from typing import List, Any, Union, IO


class Message:
Expand Down Expand Up @@ -135,7 +135,10 @@ def new_error(msg: 'Message', error_name: str, error_text: str) -> 'Message':
body=[error_text])

@staticmethod
def new_method_return(msg: 'Message', signature: str = '', body: List[Any] = []) -> 'Message':
def new_method_return(msg: 'Message',
signature: str = '',
body: List[Any] = [],
fds: List[Union[int, IO]] = []) -> 'Message':
"""A convenience constructor to create a method return to the given method call message.

:param msg: The method call message this is a reply to.
Expand All @@ -144,6 +147,8 @@ def new_method_return(msg: 'Message', signature: str = '', body: List[Any] = [])
:type signature: str
:param body: The body of this message. Must match the signature.
:type body: list(Any)
:param fds: List of filelike objects or integers representing file descriptors.
:type body: list(file object or int)

:returns: The method return message
:rtype: :class:`Message`
Expand All @@ -155,14 +160,16 @@ def new_method_return(msg: 'Message', signature: str = '', body: List[Any] = [])
reply_serial=msg.serial,
destination=msg.sender,
signature=signature,
body=body)
body=body,
unix_fds=fds)

@staticmethod
def new_signal(path: str,
interface: str,
member: str,
signature: str = '',
body: List[Any] = None) -> 'Message':
body: List[Any] = None,
fds: List[Any] = None) -> 'Message':
"""A convenience constructor to create a new signal message.

:param path: The path of this signal.
Expand Down Expand Up @@ -191,7 +198,8 @@ def new_signal(path: str,
path=path,
member=member,
signature=signature,
body=body)
body=body,
unix_fds=fds)

def _matches(self, **kwargs):
for attr, val in kwargs.items():
Expand Down Expand Up @@ -221,8 +229,8 @@ def _marshall(self):
fields.append([HeaderField.DESTINATION.value, Variant('s', self.destination)])
if self.signature:
fields.append([HeaderField.SIGNATURE.value, Variant('g', self.signature)])
for fd in self.unix_fds:
fields.append([HeaderField.UNIX_FDS.value, Variant('h', fd)])
if self.unix_fds:
fields.append([HeaderField.UNIX_FDS.value, Variant('u', len(self.unix_fds))])

header_body = [
LITTLE_ENDIAN, self.message_type.value, self.flags.value, PROTOCOL_VERSION,
Expand All @@ -232,3 +240,41 @@ def _marshall(self):
header_block.marshall()
header_block.align(8)
return header_block.buffer + body_block.buffer


def _replace_fds(body_obj, children, replace_fn):
for index, st in enumerate(children):
if not any(sig in st.signature for sig in 'hv'):
continue
if st.signature == 'h':
body_obj[index] = replace_fn(body_obj[index])
elif st.token == 'a':
if st.children[0].token == '{':
_replace_fds(body_obj[index], st.children, replace_fn)
else:
for i, child in enumerate(body_obj[index]):
if st.signature == 'ah':
body_obj[index][i] = replace_fn(child)
else:
_replace_fds([child], st.children, replace_fn)
elif st.token in '(':
_replace_fds(body_obj[index], st.children, replace_fn)
elif st.token in '{':
for key, value in list(body_obj.items()):
body_obj.pop(key)
if st.children[0].signature == 'h':
key = replace_fn(key)
if st.children[1].signature == 'h':
value = replace_fn(value)
else:
_replace_fds([value], [st.children[1]], replace_fn)
body_obj[key] = value

elif st.signature == 'v':
if body_obj[index].signature == 'h':
body_obj[index].value = replace_fn(body_obj[index].value)
else:
_replace_fds([body_obj[index].value], [body_obj[index].type], replace_fn)

elif st.children:
_replace_fds(body_obj[index], st.children, replace_fn)
Loading