diff --git a/dbus_next/_private/constants.py b/dbus_next/_private/constants.py index 22a6b80..9f9525e 100644 --- a/dbus_next/_private/constants.py +++ b/dbus_next/_private/constants.py @@ -1,6 +1,7 @@ from enum import Enum PROTOCOL_VERSION = 1 +MESSAGE_HEADER_LEN = 16 LITTLE_ENDIAN = ord('l') BIG_ENDIAN = ord('B') diff --git a/dbus_next/_private/unmarshaller.py b/dbus_next/_private/unmarshaller.py index 2bf0b2a..d2d1cc9 100644 --- a/dbus_next/_private/unmarshaller.py +++ b/dbus_next/_private/unmarshaller.py @@ -1,9 +1,11 @@ from ..message import Message -from .constants import HeaderField, LITTLE_ENDIAN, BIG_ENDIAN, PROTOCOL_VERSION +from .constants import HeaderField, LITTLE_ENDIAN, BIG_ENDIAN, PROTOCOL_VERSION, MESSAGE_HEADER_LEN from ..constants import MessageType, MessageFlag from ..signature import SignatureTree, Variant from ..errors import InvalidMessageError +import array +import socket from struct import unpack @@ -12,10 +14,12 @@ class MarshallerStreamEndError(Exception): class Unmarshaller: - def __init__(self, stream): + def __init__(self, stream, sock=None): + self.unix_fds = [] self.buf = bytearray() self.offset = 0 self.stream = stream + self.sock = sock self.endian = None self.message = None @@ -66,6 +70,25 @@ def align(self, n): return b'' return self.read(padding) + def read_endian(self, _=None): + unix_fds = array.array("i") + + if self.sock: + try: + msg, ancdata, *_ = self.sock.recvmsg(MESSAGE_HEADER_LEN, + socket.CMSG_LEN(16 * unix_fds.itemsize), + socket.MSG_PEEK) + except BlockingIOError: + return self.read_byte(), list(unix_fds) + + for level, type, data in ancdata: + if not (level == socket.SOL_SOCKET and type == socket.SCM_RIGHTS): + continue + + unix_fds.frombytes(data[:len(data) - (len(data) % unix_fds.itemsize)]) + + return self.read_byte(), list(unix_fds) + def read_byte(self, _=None): return self.read(1)[0] @@ -195,7 +218,10 @@ def read_argument(self, type_): def _unmarshall(self): self.offset = 0 - self.endian = self.read_byte() + self.endian, unix_fds = self.read_endian() + if not self.unix_fds: + self.unix_fds = unix_fds + if self.endian != LITTLE_ENDIAN and self.endian != BIG_ENDIAN: raise InvalidMessageError('Expecting endianness as the first byte') message_type = MessageType(self.read_byte()) @@ -209,13 +235,13 @@ def _unmarshall(self): body_len = self.read_uint32() serial = self.read_uint32() - header_fields = {HeaderField.UNIX_FDS.name: []} + header_fields = {} for field_struct in self.read_argument(SignatureTree('a(yv)').types[0]): field = HeaderField(field_struct[0]) if field == HeaderField.UNIX_FDS: - header_fields[field.name].append(field_struct[1].value) - else: - header_fields[field.name] = field_struct[1].value + continue + + header_fields[field.name] = field_struct[1].value self.align(8) @@ -228,7 +254,8 @@ def _unmarshall(self): sender = header_fields.get(HeaderField.SENDER.name) signature = header_fields.get(HeaderField.SIGNATURE.name, '') signature_tree = SignatureTree(signature) - unix_fds = header_fields.get(HeaderField.UNIX_FDS.name) + # TODO: check unix_fds against the header + # unix_fds = header_fields.get(HeaderField.UNIX_FDS.name) body = [] @@ -245,7 +272,7 @@ def _unmarshall(self): error_name=error_name, reply_serial=reply_serial, sender=sender, - unix_fds=unix_fds, + unix_fds=self.unix_fds, signature=signature_tree, body=body, serial=serial) diff --git a/dbus_next/aio/message_bus.py b/dbus_next/aio/message_bus.py index ee61a23..8dc812a 100644 --- a/dbus_next/aio/message_bus.py +++ b/dbus_next/aio/message_bus.py @@ -1,5 +1,6 @@ from ..message_bus import BaseMessageBus from .._private.unmarshaller import Unmarshaller +from .._private.constants import MESSAGE_HEADER_LEN from ..message import Message from ..constants import BusType, NameFlag, RequestNameReply, ReleaseNameReply, MessageType, MessageFlag from ..service import ServiceInterface @@ -8,8 +9,10 @@ from .. import introspection as intr from ..auth import Authenticator, AuthExternal +import array import logging import asyncio +import socket import traceback from typing import Optional @@ -43,9 +46,9 @@ def __init__(self, auth: Authenticator = None): super().__init__(bus_address, bus_type, ProxyObject) self._loop = asyncio.get_event_loop() - self._unmarshaller = Unmarshaller(self._stream) + self._unmarshaller = Unmarshaller(self._stream, self._sock) if auth is None: - self._auth = AuthExternal() + self._auth = AuthExternal(enable_fds=True) else: self._auth = auth @@ -222,6 +225,32 @@ def reply_handler(reply, err): return future.result() + def _sock_sendmsg(self, sock, *buffers, ancdata=None, flags=0): + fd = sock.fileno() + fut = asyncio.futures.Future(loop=self._loop) + + def __sock_sendmsg(registered=False): + if registered: + self._loop.remove_writer(fd) + + if fut.cancelled(): + return + + try: + size = sock.sendmsg(buffers, ancdata or [], flags) + except (BlockingIOError, InterruptedError): + self._loop.add_writer(fd, __sock_sendmsg, True) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(size) + + if self._loop._debug and sock.gettimeout() != 0: + raise ValueError('Socket %r must be non-blocking' % sock) + + __sock_sendmsg() + return fut + def send(self, msg: Message): if not msg.serial: msg.serial = self.next_serial() @@ -231,7 +260,16 @@ def send(self, msg: Message): self._buffered_messages.append(msg) return - asyncio.ensure_future(self._loop.sock_sendall(self._sock, msg._marshall())) + buf = msg._marshall() + + async def _send(): + ancdata = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", msg.unix_fds))] \ + if msg.unix_fds else None + + await self._sock_sendmsg(self._sock, buf[:MESSAGE_HEADER_LEN], ancdata=ancdata) + await self._loop.sock_sendall(self._sock, buf[MESSAGE_HEADER_LEN:]) + + asyncio.ensure_future(_send()) def get_proxy_object(self, bus_name: str, path: str, introspection: intr.Node) -> ProxyObject: return super().get_proxy_object(bus_name, path, introspection) @@ -245,8 +283,9 @@ def handler(msg, send_reply): def done(fut): with send_reply: result = fut.result() - body = ServiceInterface._fn_result_to_body(result, method.out_signature_tree) - send_reply(Message.new_method_return(msg, method.out_signature, body)) + body, fds = ServiceInterface._fn_result_to_body(result, + method.out_signature_tree) + send_reply(Message.new_method_return(msg, method.out_signature, body, fds)) fut = asyncio.ensure_future(method.fn(interface, *msg.body)) fut.add_done_callback(done) @@ -258,7 +297,7 @@ def _message_reader(self): while True: if self._unmarshaller.unmarshall(): self._on_message(self._unmarshaller.message) - self._unmarshaller = Unmarshaller(self._stream) + self._unmarshaller = Unmarshaller(self._stream, self._sock) else: break except Exception as e: diff --git a/dbus_next/aio/proxy_object.py b/dbus_next/aio/proxy_object.py index 7163199..4a9889d 100644 --- a/dbus_next/aio/proxy_object.py +++ b/dbus_next/aio/proxy_object.py @@ -1,6 +1,6 @@ from ..proxy_object import BaseProxyObject, BaseProxyInterface from ..message_bus import BaseMessageBus -from ..message import Message, MessageFlag +from ..message import Message, MessageFlag, _replace_fds from ..signature import Variant from ..errors import DBusError from ..constants import ErrorType @@ -82,6 +82,13 @@ async def method_fn(*args, flags=MessageFlag.NONE): BaseProxyInterface._check_method_return(msg, intr_method.out_signature) out_len = len(intr_method.out_args) + + def _replace(obj): + return msg.unix_fds[obj] + + if any(sig in msg.signature for sig in 'hv'): + _replace_fds(msg.body, msg.signature_tree.types, _replace) + if not out_len: return None elif out_len == 1: diff --git a/dbus_next/auth.py b/dbus_next/auth.py index 9f00643..b952ee2 100644 --- a/dbus_next/auth.py +++ b/dbus_next/auth.py @@ -53,6 +53,10 @@ class AuthExternal(Authenticator): :sealso: https://dbus.freedesktop.org/doc/dbus-specification.html#auth-protocol """ + def __init__(self, enable_fds=False): + self.enable_fds = enable_fds + self.negotiating_fds = False + def _authentication_start(self) -> str: hex_uid = str(os.getuid()).encode().hex() return f'AUTH EXTERNAL {hex_uid}' @@ -60,10 +64,21 @@ def _authentication_start(self) -> str: def _receive_line(self, line: str): response, args = _AuthResponse.parse(line) - if response != _AuthResponse.OK: - raise AuthError(f'authentication failed: {response.value}: {args}') + if response is _AuthResponse.OK: + if self.enable_fds: + self.NEGOTIATING_FDS = True + return "NEGOTIATE_UNIX_FD" + else: + return "BEGIN" - return 'BEGIN' + if response is _AuthResponse.AGREE_UNIX_FD: + return "BEGIN" + + if response is _AuthResponse.ERROR and self.NEGOTIATING_FDS: + # TODO: logger? + return "BEGIN" + + raise AuthError(f'authentication failed: {response.value}: {args}') class AuthAnnonymous(Authenticator): diff --git a/dbus_next/message.py b/dbus_next/message.py index 4dd3844..5c798f2 100644 --- a/dbus_next/message.py +++ b/dbus_next/message.py @@ -5,7 +5,7 @@ from .errors import InvalidMessageError from .signature import SignatureTree, Variant -from typing import List, Any +from typing import List, Any, Union, IO class Message: @@ -135,7 +135,10 @@ def new_error(msg: 'Message', error_name: str, error_text: str) -> 'Message': body=[error_text]) @staticmethod - def new_method_return(msg: 'Message', signature: str = '', body: List[Any] = []) -> 'Message': + def new_method_return(msg: 'Message', + signature: str = '', + body: List[Any] = [], + fds: List[Union[int, IO]] = []) -> 'Message': """A convenience constructor to create a method return to the given method call message. :param msg: The method call message this is a reply to. @@ -144,6 +147,8 @@ def new_method_return(msg: 'Message', signature: str = '', body: List[Any] = []) :type signature: str :param body: The body of this message. Must match the signature. :type body: list(Any) + :param fds: List of filelike objects or integers representing file descriptors. + :type body: list(file object or int) :returns: The method return message :rtype: :class:`Message` @@ -155,14 +160,16 @@ def new_method_return(msg: 'Message', signature: str = '', body: List[Any] = []) reply_serial=msg.serial, destination=msg.sender, signature=signature, - body=body) + body=body, + unix_fds=fds) @staticmethod def new_signal(path: str, interface: str, member: str, signature: str = '', - body: List[Any] = None) -> 'Message': + body: List[Any] = None, + fds: List[Any] = None) -> 'Message': """A convenience constructor to create a new signal message. :param path: The path of this signal. @@ -191,7 +198,8 @@ def new_signal(path: str, path=path, member=member, signature=signature, - body=body) + body=body, + unix_fds=fds) def _matches(self, **kwargs): for attr, val in kwargs.items(): @@ -221,8 +229,8 @@ def _marshall(self): fields.append([HeaderField.DESTINATION.value, Variant('s', self.destination)]) if self.signature: fields.append([HeaderField.SIGNATURE.value, Variant('g', self.signature)]) - for fd in self.unix_fds: - fields.append([HeaderField.UNIX_FDS.value, Variant('h', fd)]) + if self.unix_fds: + fields.append([HeaderField.UNIX_FDS.value, Variant('u', len(self.unix_fds))]) header_body = [ LITTLE_ENDIAN, self.message_type.value, self.flags.value, PROTOCOL_VERSION, @@ -232,3 +240,41 @@ def _marshall(self): header_block.marshall() header_block.align(8) return header_block.buffer + body_block.buffer + + +def _replace_fds(body_obj, children, replace_fn): + for index, st in enumerate(children): + if not any(sig in st.signature for sig in 'hv'): + continue + if st.signature == 'h': + body_obj[index] = replace_fn(body_obj[index]) + elif st.token == 'a': + if st.children[0].token == '{': + _replace_fds(body_obj[index], st.children, replace_fn) + else: + for i, child in enumerate(body_obj[index]): + if st.signature == 'ah': + body_obj[index][i] = replace_fn(child) + else: + _replace_fds([child], st.children, replace_fn) + elif st.token in '(': + _replace_fds(body_obj[index], st.children, replace_fn) + elif st.token in '{': + for key, value in list(body_obj.items()): + body_obj.pop(key) + if st.children[0].signature == 'h': + key = replace_fn(key) + if st.children[1].signature == 'h': + value = replace_fn(value) + else: + _replace_fds([value], [st.children[1]], replace_fn) + body_obj[key] = value + + elif st.signature == 'v': + if body_obj[index].signature == 'h': + body_obj[index].value = replace_fn(body_obj[index].value) + else: + _replace_fds([body_obj[index].value], [body_obj[index].type], replace_fn) + + elif st.children: + _replace_fds(body_obj[index], st.children, replace_fn) diff --git a/dbus_next/message_bus.py b/dbus_next/message_bus.py index 2a25f6b..f338995 100644 --- a/dbus_next/message_bus.py +++ b/dbus_next/message_bus.py @@ -454,7 +454,7 @@ def _has_interface(self, interface: ServiceInterface) -> bool: return False - def _interface_signal_notify(self, interface, interface_name, member, signature, body): + def _interface_signal_notify(self, interface, interface_name, member, signature, body, fds=[]): path = None for p, ifaces in self._path_exports.items(): for i in ifaces: @@ -469,7 +469,8 @@ def _interface_signal_notify(self, interface, interface_name, member, signature, interface=interface_name, member=member, signature=signature, - body=body)) + body=body, + fds=fds)) def _introspect_export_path(self, path): assert_object_path_valid(path) @@ -672,8 +673,9 @@ def _process_message(self, msg): def _make_method_handler(cls, interface, method): def handler(msg, send_reply): result = method.fn(interface, *msg.body) - body = ServiceInterface._fn_result_to_body(result, method.out_signature_tree) - send_reply(Message.new_method_return(msg, method.out_signature, body)) + body, fds = ServiceInterface._fn_result_to_body( + result, signature_tree=method.out_signature_tree) + send_reply(Message.new_method_return(msg, method.out_signature, body, fds)) return handler diff --git a/dbus_next/service.py b/dbus_next/service.py index bd390cd..d5ba309 100644 --- a/dbus_next/service.py +++ b/dbus_next/service.py @@ -2,10 +2,12 @@ from .signature import SignatureTree, SignatureBodyMismatchError, Variant from . import introspection as intr from .errors import SignalDisabledError +from .message import _replace_fds from functools import wraps import inspect from typing import no_type_check_decorator, Dict, List, Any +from os import dup # TODO: if the user uses `from __future__ import annotations` in their code, # the annotation inspection will not work because of PEP 563. We will get @@ -431,23 +433,42 @@ def _get_signals(interface): @staticmethod def _fn_result_to_body(result, signature_tree): + # TODO: find better way of passing results here instead of list out_len = len(signature_tree.types) if result is None: - body = [] - elif out_len == 0: - raise SignatureBodyMismatchError('Function was not expected to return an argument') - elif out_len == 1: - body = [result] - elif type(result) is not list: - raise SignatureBodyMismatchError('Expected function to return a list of arguments') + result = [] else: - body = result + if out_len == 1: + result = [result] + else: + if type(result) is not list: + raise SignatureBodyMismatchError( + 'Expected signal to return a list of arguments') - return body + if out_len != len(result): + raise SignatureBodyMismatchError( + "Signature and function return mismatch, expected %s arguments but got %s", + (len(signature_tree.types), len(result))) + + body = list(result) + fds = [] + + def _replace(obj): + if hasattr(obj, "fileno"): + obj = dup(obj.fileno()) + fds.append(obj) + + return len(fds) - 1 + + if any(sig in signature_tree.signature for sig in 'hv'): + _replace_fds(body, signature_tree.types, _replace) + + return body, fds @staticmethod def _handle_signal(interface, signal, result): - body = ServiceInterface._fn_result_to_body(result, signal.signature_tree) + body, fds = ServiceInterface._fn_result_to_body(result, signal.signature_tree) for bus in ServiceInterface._get_buses(interface): + # TODO: can signal pass fds? bus._interface_signal_notify(interface, interface.name, signal.name, signal.signature, - body) + body, fds) diff --git a/test/test_fd_passing.py b/test/test_fd_passing.py new file mode 100644 index 0000000..9cd7a1c --- /dev/null +++ b/test/test_fd_passing.py @@ -0,0 +1,133 @@ +"""This tests the ability to send and receive file descriptors in dbus messages""" +from dbus_next.service import ServiceInterface, method +from dbus_next.signature import SignatureTree, Variant +from dbus_next.aio import MessageBus +from dbus_next import Message, MessageFlag +import socket + +import pytest + + +class ExampleInterface(ServiceInterface): + def __init__(self, name): + super().__init__(name) + self.fd = 0 + + @method() + async def echofd(self) -> 'h': + f = socket.socket() + return f + + +@pytest.mark.asyncio +async def test_sending_file_descriptor(): + bus1 = await MessageBus().connect() + bus2 = await MessageBus().connect() + + interface = ExampleInterface('test.interface') + export_path = '/test/path' + + async def call(member, signature='', body=[], flags=MessageFlag.NONE): + return await bus2.call( + Message(destination=bus1.unique_name, + path=export_path, + interface=interface.name, + member=member, + signature=signature, + body=body, + flags=flags)) + + bus1.export(export_path, interface) + + reply = await call('echofd') + + sock = socket.fromfd(reply.unix_fds[0], family=-1, type=-1) + assert sock + + +@pytest.mark.asyncio +async def test_sending_file_descriptor_with_proxy(): + name = 'dbus.next.test.service' + path = '/test/path' + interface_name = 'test.interface' + + bus = await MessageBus().connect() + interface = ExampleInterface(interface_name) + bus.export(path, interface) + await bus.request_name(name) + + intr = await bus.introspect(name, path) + + proxy = bus.get_proxy_object(name, path, intr) + proxy_interface = proxy.get_interface(interface_name) + reply = await proxy_interface.call_echofd() + sock = socket.fromfd(reply, family=-1, type=-1) + assert sock + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "result, out_signature, expected", + [ + pytest.param(5, 'h', ([0], [5]), id='Signature: "h"'), + pytest.param([5, "foo"], 'hs', ([0, "foo"], [5]), id='Signature: "hs"'), + pytest.param([5, 7], 'hh', ([0, 1], [5, 7]), id='Signature: "hh"'), + pytest.param([5, 7], 'ah', ([[0, 1]], [5, 7]), id='Signature: "ah"'), + pytest.param([9], 'ah', ([[0]], [9]), id='Signature: "ah"'), + pytest.param([3], '(h)', ([[0]], [3]), id='Signature: "(h)"'), + pytest.param([3, "foo"], '(hs)', ([[0, "foo"]], [3]), id='Signature: "(hs)"'), + pytest.param([[7, "foo"], [8, "bar"]], + 'a(hs)', ([[[0, "foo"], [1, "bar"]]], [7, 8]), + id='Signature: "a(hs)"'), + pytest.param({"foo": 3}, 'a{sh}', ([{ + "foo": 0 + }], [3]), id='Signature: "a{sh}"'), + pytest.param({ + "foo": 3, + "bar": 6 + }, + 'a{sh}', ([{ + "foo": 0, + "bar": 1 + }], [3, 6]), + id='Signature: "a{sh}"'), + pytest.param( + {"foo": [3, 8]}, 'a{sah}', ([{ + "foo": [0, 1] + }], [3, 8]), id='Signature: "a{sah}"'), + pytest.param({'foo': Variant('t', 100)}, + 'a{sv}', ([{ + 'foo': Variant('t', 100) + }], []), + id='Signature: "a{sv}"'), + pytest.param(['one', ['two', [Variant('s', 'three')]]], + '(s(s(v)))', ([['one', ['two', [Variant('s', 'three')]]]], []), + id='Signature: "(s(s(v)))"'), + pytest.param(Variant('h', 2), 'v', ([Variant('h', 0)], [2]), id='Variant with: "h"'), + pytest.param(Variant('(hh)', [2, 8]), + 'v', ([Variant('(hh)', [0, 1])], [2, 8]), + id='Variant with: "(hh)"'), + pytest.param( + Variant('ah', [2, 4]), 'v', ([Variant('ah', [0, 1])], [2, 4]), id='Variant with: "ah"'), + pytest.param(Variant('(ss)', ['hello', 'world']), + 'v', ([Variant('(ss)', ['hello', 'world'])], []), + id='Variant with: "(ss)"'), + pytest.param(Variant('v', Variant('t', 100)), + 'v', ([Variant('v', Variant('t', 100))], []), + id='Variant with: "v"'), + pytest.param([ + Variant('v', Variant('(ss)', ['hello', 'world'])), { + 'foo': Variant('t', 100) + }, ['one', ['two', [Variant('s', 'three')]]] + ], + 'va{sv}(s(s(v)))', ([ + Variant('v', Variant('(ss)', ['hello', 'world'])), { + 'foo': Variant('t', 100) + }, ['one', ['two', [Variant('s', 'three')]]] + ], []), + id='Variant with: "va{sv}(s(s(v)))"'), + ], +) +async def test_fn_result_to_body(result, out_signature, expected): + out_signature_tree = SignatureTree(out_signature) + assert ServiceInterface._fn_result_to_body(result, out_signature_tree) == expected