From 14fd99e2b16883937cbbd24259092669c7263aa9 Mon Sep 17 00:00:00 2001 From: Niraj Kamdar Date: Thu, 2 Mar 2023 23:39:10 +0400 Subject: [PATCH] fix: msgpack bug fix --- .../polywrap_msgpack/__init__.py | 34 +++++++++++++++---- .../polywrap_msgpack/generic_map.py | 29 ++++++++++++++++ .../polywrap-msgpack/tests/test_msgpack.py | 13 +++---- .../typings/msgpack/__init__.pyi | 4 +-- .../polywrap_plugin/wrapper.py | 2 +- .../abc/uri_resolver_aggregator.py | 16 ++------- .../polywrap-wasm/polywrap_wasm/imports.py | 6 ++-- 7 files changed, 74 insertions(+), 30 deletions(-) create mode 100644 packages/polywrap-msgpack/polywrap_msgpack/generic_map.py diff --git a/packages/polywrap-msgpack/polywrap_msgpack/__init__.py b/packages/polywrap-msgpack/polywrap_msgpack/__init__.py index fd990259..15eed0b2 100644 --- a/packages/polywrap-msgpack/polywrap_msgpack/__init__.py +++ b/packages/polywrap-msgpack/polywrap_msgpack/__init__.py @@ -8,11 +8,14 @@ custom extension types defined by wrap standard """ from enum import Enum -from typing import Any, Dict, List, Set, cast +from typing import Any, Dict, List, Set, Tuple, cast import msgpack +from msgpack.ext import ExtType from msgpack.exceptions import UnpackValueError +from .generic_map import GenericMap + class ExtensionTypes(Enum): """Wrap msgpack extension types.""" @@ -20,7 +23,24 @@ class ExtensionTypes(Enum): GENERIC_MAP = 1 -def ext_hook(code: int, data: bytes) -> Any: +def encode_ext_hook(obj: Any) -> ExtType: + """Extension hook for extending the msgpack supported types. + + Args: + obj (Any): object to be encoded + + Raises: + TypeError: when given object is not supported + + Returns: + Tuple[int, bytes]: extension type code and payload + """ + if isinstance(obj, GenericMap): + return ExtType(ExtensionTypes.GENERIC_MAP.value, msgpack_encode(obj._map)) # type: ignore + raise TypeError(f"Object of type {type(obj)} is not supported") + + +def decode_ext_hook(code: int, data: bytes) -> Any: """Extension hook for extending the msgpack supported types. Args: @@ -34,7 +54,7 @@ def ext_hook(code: int, data: bytes) -> Any: Any: decoded object """ if code == ExtensionTypes.GENERIC_MAP.value: - return msgpack_decode(data) + return GenericMap(msgpack_decode(data)) raise UnpackValueError("Invalid Extention type") @@ -50,6 +70,8 @@ def sanitize(value: Any) -> Any: Returns: Any: msgpack compatible sanitized value """ + if isinstance(value, GenericMap): + return cast(Any, value) if isinstance(value, dict): dictionary: Dict[Any, Any] = value for key, val in dictionary.items(): @@ -59,7 +81,7 @@ def sanitize(value: Any) -> Any: array: List[Any] = value return [sanitize(a) for a in array] if isinstance(value, tuple): - array: List[Any] = list(value) # type: ignore partially unknown + array: List[Any] = list(cast(Tuple[Any], value)) return sanitize(array) if isinstance(value, set): set_val: Set[Any] = value @@ -87,7 +109,7 @@ def msgpack_encode(value: Any) -> bytes: bytes: encoded msgpack value """ sanitized = sanitize(value) - return msgpack.packb(sanitized) + return msgpack.packb(sanitized, default=encode_ext_hook, use_bin_type=True) def msgpack_decode(val: bytes) -> Any: @@ -99,4 +121,4 @@ def msgpack_decode(val: bytes) -> Any: Returns: Any: python object """ - return msgpack.unpackb(val, ext_hook=ext_hook) + return msgpack.unpackb(val, ext_hook=decode_ext_hook) diff --git a/packages/polywrap-msgpack/polywrap_msgpack/generic_map.py b/packages/polywrap-msgpack/polywrap_msgpack/generic_map.py new file mode 100644 index 00000000..a8f5ee20 --- /dev/null +++ b/packages/polywrap-msgpack/polywrap_msgpack/generic_map.py @@ -0,0 +1,29 @@ +from typing import Dict, MutableMapping, TypeVar + +K = TypeVar("K") +V = TypeVar("V") + + +class GenericMap(MutableMapping[K, V]): + _map: Dict[K, V] + + def __init__(self, map: Dict[K, V]): + self._map = dict(map) + + def __getitem__(self, key: K) -> V: + return self._map[key] + + def __setitem__(self, key: K, value: V) -> None: + self._map[key] = value + + def __delitem__(self, key: K) -> None: + del self._map[key] + + def __iter__(self): + return iter(self._map) + + def __len__(self) -> int: + return len(self._map) + + def __repr__(self) -> str: + return f"GenericMap({repr(self._map)})" diff --git a/packages/polywrap-msgpack/tests/test_msgpack.py b/packages/polywrap-msgpack/tests/test_msgpack.py index e4525f67..2a3a6ca0 100644 --- a/packages/polywrap-msgpack/tests/test_msgpack.py +++ b/packages/polywrap-msgpack/tests/test_msgpack.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Set, Tuple from polywrap_msgpack import msgpack_decode, msgpack_encode, sanitize +from polywrap_msgpack.generic_map import GenericMap from tests.conftest import DataClassObject, DataClassObjectWithSlots, Example # ENCODING AND DECODING @@ -147,16 +148,10 @@ def test_sanitize_set_returns_list_with_all_items_of_the_set( set1: Set[Any], set2: Set[Any] ): sanitized = sanitize(set1) - # r: List[bool] = [] assert list(set1) == sanitized - # [r.append(True) if item in sanitized else r.append(False) for item in set1] - # assert False not in r sanitized = sanitize(set2) assert list(set2) == sanitized - # r = [] - # [r.append(True) if item in sanitized else r.append(False) for item in set2] - # assert False not in r def test_sanitize_set_returns_list_of_same_length(set1: Set[Any]): @@ -261,3 +256,9 @@ def test_sanitize_dict_of_dataclass_objects_with_slots_returns_list_of_dicts( "firstKey": dataclass_object_with_slots1_sanitized, "secondKey": dataclass_object_with_slots2_sanitized, } + + +def test_encode_generic_map(): + generic_map = GenericMap({"firstKey": "firstValue", "secondKey": "secondValue"}) + assert sanitize(generic_map) == generic_map + assert msgpack_decode(msgpack_encode(generic_map)) == generic_map diff --git a/packages/polywrap-msgpack/typings/msgpack/__init__.pyi b/packages/polywrap-msgpack/typings/msgpack/__init__.pyi index fdc39f59..5edb76b4 100644 --- a/packages/polywrap-msgpack/typings/msgpack/__init__.pyi +++ b/packages/polywrap-msgpack/typings/msgpack/__init__.pyi @@ -24,7 +24,7 @@ def pack(o: Any, stream: IOBase, **kwargs: Dict[Any, Any]) -> IOBase: # -> None """ ... -def packb(o: Any, **kwargs: Dict[Any, Any]) -> bytes: # -> None: +def packb(o: Any, default: Optional[Callable[[Any], ExtType]], use_bin_type: bool, **kwargs: Dict[Any, Any]) -> bytes: # -> None: """ Pack object `o` and return packed bytes @@ -32,7 +32,7 @@ def packb(o: Any, **kwargs: Dict[Any, Any]) -> bytes: # -> None: """ ... -def unpack(stream: IOBase, **kwargs: Dict[Any, Any]) -> Any: +def unpack(stream: IOBase,**kwargs: Dict[Any, Any]) -> Any: """ Unpack an object from `stream`. diff --git a/packages/polywrap-plugin/polywrap_plugin/wrapper.py b/packages/polywrap-plugin/polywrap_plugin/wrapper.py index 247754fc..1285ae42 100644 --- a/packages/polywrap-plugin/polywrap_plugin/wrapper.py +++ b/packages/polywrap-plugin/polywrap_plugin/wrapper.py @@ -35,7 +35,7 @@ async def invoke( args: Union[Dict[str, Any], bytes] = options.args or {} decoded_args: Dict[str, Any] = ( - msgpack_decode(args) if isinstance(args, bytes) else args + msgpack_decode(args) if isinstance(args, (bytes, bytearray)) else args ) result = cast( diff --git a/packages/polywrap-uri-resolvers/polywrap_uri_resolvers/abc/uri_resolver_aggregator.py b/packages/polywrap-uri-resolvers/polywrap_uri_resolvers/abc/uri_resolver_aggregator.py index de312392..74fdb531 100644 --- a/packages/polywrap-uri-resolvers/polywrap_uri_resolvers/abc/uri_resolver_aggregator.py +++ b/packages/polywrap-uri-resolvers/polywrap_uri_resolvers/abc/uri_resolver_aggregator.py @@ -46,19 +46,9 @@ async def try_resolve_uri_with_resolvers( for resolver in resolvers: result = await resolver.try_resolve_uri(uri, client, sub_context) - - if result.is_ok(): - return result - - if not result.is_ok(): - step = IUriResolutionStep( - source_uri=uri, - result=result, - sub_history=sub_context.get_history(), - description=self.get_step_description(), - ) - resolution_context.track_step(step) - + if result.is_ok() and not ( + isinstance(result.unwrap(), Uri) and result.unwrap() == uri + ): return result result = Ok(uri) diff --git a/packages/polywrap-wasm/polywrap_wasm/imports.py b/packages/polywrap-wasm/polywrap_wasm/imports.py index 4b5e163e..7f65c98c 100644 --- a/packages/polywrap-wasm/polywrap_wasm/imports.py +++ b/packages/polywrap-wasm/polywrap_wasm/imports.py @@ -200,8 +200,10 @@ def wrap_subinvoke( result = unfuture_result.result() if result.is_ok(): - result = cast(Ok[bytes], result) - state.subinvoke["result"] = result.unwrap() + if isinstance(result.unwrap(), (bytes, bytearray)): + state.subinvoke["result"] = result.unwrap() + return True + state.subinvoke["result"] = msgpack_encode(result.unwrap()) return True elif result.is_err(): error = cast(Err, result).unwrap_err()