Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions packages/polywrap-msgpack/polywrap_msgpack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,39 @@
custom extension types defined by wrap standard
"""
from enum import Enum
from typing import Any, Dict, List, Set, cast
from typing import Any, Dict, List, Set, Tuple, cast

import msgpack
from msgpack.ext import ExtType
from msgpack.exceptions import UnpackValueError

from .generic_map import GenericMap


class ExtensionTypes(Enum):
"""Wrap msgpack extension types."""

GENERIC_MAP = 1


def ext_hook(code: int, data: bytes) -> Any:
def encode_ext_hook(obj: Any) -> ExtType:
"""Extension hook for extending the msgpack supported types.

Args:
obj (Any): object to be encoded

Raises:
TypeError: when given object is not supported

Returns:
Tuple[int, bytes]: extension type code and payload
"""
if isinstance(obj, GenericMap):
return ExtType(ExtensionTypes.GENERIC_MAP.value, msgpack_encode(obj._map)) # type: ignore
raise TypeError(f"Object of type {type(obj)} is not supported")


def decode_ext_hook(code: int, data: bytes) -> Any:
"""Extension hook for extending the msgpack supported types.

Args:
Expand All @@ -34,7 +54,7 @@ def ext_hook(code: int, data: bytes) -> Any:
Any: decoded object
"""
if code == ExtensionTypes.GENERIC_MAP.value:
return msgpack_decode(data)
return GenericMap(msgpack_decode(data))
raise UnpackValueError("Invalid Extention type")


Expand All @@ -50,6 +70,8 @@ def sanitize(value: Any) -> Any:
Returns:
Any: msgpack compatible sanitized value
"""
if isinstance(value, GenericMap):
return cast(Any, value)
if isinstance(value, dict):
dictionary: Dict[Any, Any] = value
for key, val in dictionary.items():
Expand All @@ -59,7 +81,7 @@ def sanitize(value: Any) -> Any:
array: List[Any] = value
return [sanitize(a) for a in array]
if isinstance(value, tuple):
array: List[Any] = list(value) # type: ignore partially unknown
array: List[Any] = list(cast(Tuple[Any], value))
return sanitize(array)
if isinstance(value, set):
set_val: Set[Any] = value
Expand Down Expand Up @@ -87,7 +109,7 @@ def msgpack_encode(value: Any) -> bytes:
bytes: encoded msgpack value
"""
sanitized = sanitize(value)
return msgpack.packb(sanitized)
return msgpack.packb(sanitized, default=encode_ext_hook, use_bin_type=True)


def msgpack_decode(val: bytes) -> Any:
Expand All @@ -99,4 +121,4 @@ def msgpack_decode(val: bytes) -> Any:
Returns:
Any: python object
"""
return msgpack.unpackb(val, ext_hook=ext_hook)
return msgpack.unpackb(val, ext_hook=decode_ext_hook)
29 changes: 29 additions & 0 deletions packages/polywrap-msgpack/polywrap_msgpack/generic_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Dict, MutableMapping, TypeVar

K = TypeVar("K")
V = TypeVar("V")


class GenericMap(MutableMapping[K, V]):
_map: Dict[K, V]

def __init__(self, map: Dict[K, V]):
self._map = dict(map)

def __getitem__(self, key: K) -> V:
return self._map[key]

def __setitem__(self, key: K, value: V) -> None:
self._map[key] = value

def __delitem__(self, key: K) -> None:
del self._map[key]

def __iter__(self):
return iter(self._map)

def __len__(self) -> int:
return len(self._map)

def __repr__(self) -> str:
return f"GenericMap({repr(self._map)})"
13 changes: 7 additions & 6 deletions packages/polywrap-msgpack/tests/test_msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, List, Set, Tuple

from polywrap_msgpack import msgpack_decode, msgpack_encode, sanitize
from polywrap_msgpack.generic_map import GenericMap
from tests.conftest import DataClassObject, DataClassObjectWithSlots, Example

# ENCODING AND DECODING
Expand Down Expand Up @@ -147,16 +148,10 @@ def test_sanitize_set_returns_list_with_all_items_of_the_set(
set1: Set[Any], set2: Set[Any]
):
sanitized = sanitize(set1)
# r: List[bool] = []
assert list(set1) == sanitized
# [r.append(True) if item in sanitized else r.append(False) for item in set1]
# assert False not in r

sanitized = sanitize(set2)
assert list(set2) == sanitized
# r = []
# [r.append(True) if item in sanitized else r.append(False) for item in set2]
# assert False not in r


def test_sanitize_set_returns_list_of_same_length(set1: Set[Any]):
Expand Down Expand Up @@ -261,3 +256,9 @@ def test_sanitize_dict_of_dataclass_objects_with_slots_returns_list_of_dicts(
"firstKey": dataclass_object_with_slots1_sanitized,
"secondKey": dataclass_object_with_slots2_sanitized,
}


def test_encode_generic_map():
generic_map = GenericMap({"firstKey": "firstValue", "secondKey": "secondValue"})
assert sanitize(generic_map) == generic_map
assert msgpack_decode(msgpack_encode(generic_map)) == generic_map
4 changes: 2 additions & 2 deletions packages/polywrap-msgpack/typings/msgpack/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ def pack(o: Any, stream: IOBase, **kwargs: Dict[Any, Any]) -> IOBase: # -> None
"""
...

def packb(o: Any, **kwargs: Dict[Any, Any]) -> bytes: # -> None:
def packb(o: Any, default: Optional[Callable[[Any], ExtType]], use_bin_type: bool, **kwargs: Dict[Any, Any]) -> bytes: # -> None:
"""
Pack object `o` and return packed bytes

See :class:`Packer` for options.
"""
...

def unpack(stream: IOBase, **kwargs: Dict[Any, Any]) -> Any:
def unpack(stream: IOBase,**kwargs: Dict[Any, Any]) -> Any:
"""
Unpack an object from `stream`.

Expand Down
2 changes: 1 addition & 1 deletion packages/polywrap-plugin/polywrap_plugin/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def invoke(

args: Union[Dict[str, Any], bytes] = options.args or {}
decoded_args: Dict[str, Any] = (
msgpack_decode(args) if isinstance(args, bytes) else args
msgpack_decode(args) if isinstance(args, (bytes, bytearray)) else args
)

result = cast(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,9 @@ async def try_resolve_uri_with_resolvers(

for resolver in resolvers:
result = await resolver.try_resolve_uri(uri, client, sub_context)

if result.is_ok():
return result

if not result.is_ok():
step = IUriResolutionStep(
source_uri=uri,
result=result,
sub_history=sub_context.get_history(),
description=self.get_step_description(),
)
resolution_context.track_step(step)

if result.is_ok() and not (
isinstance(result.unwrap(), Uri) and result.unwrap() == uri
):
return result

result = Ok(uri)
Expand Down
6 changes: 4 additions & 2 deletions packages/polywrap-wasm/polywrap_wasm/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,10 @@ def wrap_subinvoke(
result = unfuture_result.result()

if result.is_ok():
result = cast(Ok[bytes], result)
state.subinvoke["result"] = result.unwrap()
if isinstance(result.unwrap(), (bytes, bytearray)):
state.subinvoke["result"] = result.unwrap()
return True
state.subinvoke["result"] = msgpack_encode(result.unwrap())
return True
elif result.is_err():
error = cast(Err, result).unwrap_err()
Expand Down