forked from jekirl/poketrainer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path protobuf_to_dict.py
More file actions
177 lines (138 loc) · 6.57 KB
/
protobuf_to_dict.py
File metadata and controls
177 lines (138 loc) · 6.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import base64
import six
from google.protobuf.descriptor import FieldDescriptor
from google.protobuf.message import Message
__all__ = [
    "protobuf_to_dict",
    "TYPE_CALLABLE_MAP",
    "dict_to_protobuf",
    "REVERSE_TYPE_CALLABLE_MAP",
]
# Dictionary key under which protobuf_to_dict gathers extension values and
# from which dict_to_protobuf reads them back.
EXTENSION_CONTAINER = '___X'
# Python type used for every 64-bit integer field: ``int`` on Python 3, the
# second entry of ``six.integer_types`` (``long``) on Python 2.
_INT64_TYPE = int if six.PY3 else six.integer_types[1]
# Default per-type converters used by protobuf_to_dict, keyed by the
# FieldDescriptor.TYPE_* id of the field being converted.
TYPE_CALLABLE_MAP = {
    FieldDescriptor.TYPE_DOUBLE: float,
    FieldDescriptor.TYPE_FLOAT: float,
    FieldDescriptor.TYPE_INT32: int,
    FieldDescriptor.TYPE_INT64: _INT64_TYPE,
    FieldDescriptor.TYPE_UINT32: int,
    FieldDescriptor.TYPE_UINT64: _INT64_TYPE,
    FieldDescriptor.TYPE_SINT32: int,
    FieldDescriptor.TYPE_SINT64: _INT64_TYPE,
    FieldDescriptor.TYPE_FIXED32: int,
    FieldDescriptor.TYPE_FIXED64: _INT64_TYPE,
    FieldDescriptor.TYPE_SFIXED32: int,
    FieldDescriptor.TYPE_SFIXED64: _INT64_TYPE,
    FieldDescriptor.TYPE_BOOL: bool,
    FieldDescriptor.TYPE_STRING: six.text_type,
    # Raw bytes are base64-encoded (base64.b64encode returns bytes on Python 3).
    FieldDescriptor.TYPE_BYTES: lambda raw: base64.b64encode(raw),
    # Enums are emitted as their numeric value by default.
    FieldDescriptor.TYPE_ENUM: int,
}
def repeated(type_callable):
    """Lift a single-value converter so it applies to every element of a sequence.

    :param type_callable: a one-argument converter.
    :returns: a callable taking an iterable and returning a new ``list`` of
        converted elements.
    """
    def convert_each(value_list):
        return list(map(type_callable, value_list))
    return convert_each
def enum_label_name(field, value):
    """Return the label name of the enum value numbered ``value`` on ``field``.

    ``value`` is passed through ``int()`` before the lookup, so numeric strings
    are accepted. The lookup indexes ``field.enum_type.values_by_number``
    directly, so an unknown number raises whatever that mapping raises.
    """
    number = int(value)
    enum_value = field.enum_type.values_by_number[number]
    return enum_value.name
def protobuf_to_dict(pb, type_callable_map=TYPE_CALLABLE_MAP, use_enum_labels=False):
    """Convert a protobuf message into a plain dictionary.

    Only the fields returned by ``pb.ListFields()`` appear in the result.
    Sub-messages are converted recursively and repeated fields become lists.
    Map fields become dicts, with message-valued map entries converted
    recursively as well. Extension values are gathered into a nested dict
    stored under ``EXTENSION_CONTAINER`` and keyed by ``str(field.number)``.

    :param pb: the protobuf message instance to convert.
    :param dict type_callable_map: maps ``FieldDescriptor.TYPE_*`` ids to
        callables used to convert scalar field values.
    :param bool use_enum_labels: if true, enum fields are emitted as their
        label names instead of their numbers.
    :returns: a new ``dict`` mapping field names to converted values.
    :raises TypeError: if a set field's type has no converter in
        ``type_callable_map``.
    """
    result_dict = {}
    extensions = {}
    for field, value in pb.ListFields():
        if field.message_type and field.message_type.has_options and \
                field.message_type.GetOptions().map_entry:
            value_field = field.message_type.fields_by_name['value']
            if value_field.type == FieldDescriptor.TYPE_MESSAGE:
                # Message-valued map: convert each value so the result holds
                # plain dicts rather than protobuf message objects.
                result_dict[field.name] = {
                    key: protobuf_to_dict(item,
                                          type_callable_map=type_callable_map,
                                          use_enum_labels=use_enum_labels)
                    for key, item in value.items()
                }
            else:
                # Scalar-valued map: keys and values are already plain values.
                result_dict[field.name] = dict(value)
            continue
        type_callable = _get_field_value_adaptor(pb, field, type_callable_map, use_enum_labels)
        if field.label == FieldDescriptor.LABEL_REPEATED:
            type_callable = repeated(type_callable)
        if field.is_extension:
            # Extensions are keyed by their number, not their name.
            extensions[str(field.number)] = type_callable(value)
            continue
        result_dict[field.name] = type_callable(value)
    if extensions:
        result_dict[EXTENSION_CONTAINER] = extensions
    return result_dict
def _get_field_value_adaptor(pb, field, type_callable_map=TYPE_CALLABLE_MAP, use_enum_labels=False):
    """Choose the callable that converts a single value of ``field``.

    :param pb: the message that owns ``field``; used only in the error message.
    :param field: the ``FieldDescriptor`` whose values are being converted.
    :param dict type_callable_map: scalar converters keyed by field type id.
    :param bool use_enum_labels: convert enum values to label names.
    :returns: a one-argument converter.
    :raises TypeError: if ``field.type`` is not a message type, not a
        label-converted enum, and not present in ``type_callable_map``.
    """
    field_type = field.type

    if field_type == FieldDescriptor.TYPE_MESSAGE:
        def _convert_message(sub_message):
            # Sub-messages recurse with the same conversion options.
            return protobuf_to_dict(sub_message,
                                    type_callable_map=type_callable_map,
                                    use_enum_labels=use_enum_labels)
        return _convert_message

    if use_enum_labels and field_type == FieldDescriptor.TYPE_ENUM:
        def _convert_enum(value):
            return enum_label_name(field, value)
        return _convert_enum

    if field_type not in type_callable_map:
        raise TypeError("Field %s.%s has unrecognised type id %d" % (
            pb.__class__.__name__, field.name, field.type))
    return type_callable_map[field_type]
def get_bytes(value):
    """Decode a base64-encoded ``str`` or ``bytes`` value back into raw bytes."""
    decoded = base64.b64decode(value)
    return decoded
# Default converters for dict_to_protobuf, keyed by FieldDescriptor.TYPE_* id:
# bytes fields arrive base64-encoded and are decoded with get_bytes.
REVERSE_TYPE_CALLABLE_MAP = {FieldDescriptor.TYPE_BYTES: get_bytes}
def dict_to_protobuf(pb_klass_or_instance, values, type_callable_map=REVERSE_TYPE_CALLABLE_MAP, strict=True):
    """Populate a protobuf message from a dictionary.

    :param pb_klass_or_instance: a protobuf message class, which is
        instantiated with no arguments, or an existing message instance, which
        is filled in place.
    :type pb_klass_or_instance: a subclass of
        ``google.protobuf.message.Message`` or an instance of one.
    :param dict values: field name -> value. Nested dicts populate
        sub-messages, lists populate repeated fields, and extension values are
        read from the ``EXTENSION_CONTAINER`` key.
    :param dict type_callable_map: maps ``FieldDescriptor.TYPE_*`` ids to
        callables that convert scalar input values before assignment.
    :param bool strict: if true, raise ``KeyError`` for keys that are not
        fields (or registered extensions) of the message; otherwise skip them.
    :returns: the populated message instance.
    """
    if not isinstance(pb_klass_or_instance, Message):
        # A class was supplied: build a fresh, empty instance to fill.
        pb_klass_or_instance = pb_klass_or_instance()
    return _dict_to_protobuf(pb_klass_or_instance, values, type_callable_map, strict)
def _get_field_mapping(pb, dict_value, strict):
    """Pair each entry of ``dict_value`` with the protobuf field it targets.

    Regular keys are looked up by field name on ``pb.DESCRIPTOR``. Entries
    nested under ``EXTENSION_CONTAINER`` are looked up by extension number.

    :param pb: the message being populated.
    :param dict dict_value: the input dictionary for ``pb``.
    :param bool strict: raise on unknown names/numbers instead of skipping them.
    :returns: a list of ``(field_descriptor, input_value, current_pb_value)``
        tuples, where ``current_pb_value`` is the attribute (or extension)
        currently held by ``pb`` for that field.
    :raises KeyError: in strict mode, for an unknown field name or an
        unregistered extension number.
    :raises ValueError: if an extension key cannot be converted to ``int``.
    """
    fields_by_name = pb.DESCRIPTOR.fields_by_name
    field_mapping = []
    for key, value in dict_value.items():
        if key == EXTENSION_CONTAINER:
            continue
        if key not in fields_by_name:
            if strict:
                raise KeyError("%s does not have a field called %s" % (pb, key))
            continue
        field_mapping.append((fields_by_name[key], value, getattr(pb, key, None)))

    for ext_num, ext_val in dict_value.get(EXTENSION_CONTAINER, {}).items():
        try:
            ext_num = int(ext_num)
        except (TypeError, ValueError):
            # TypeError covers keys such as None that int() cannot accept at all.
            raise ValueError("Extension keys must be integers.")
        if ext_num not in pb._extensions_by_number:
            if strict:
                # Report the extension number actually being looked up.
                raise KeyError(
                    "%s does not have an extension with number %s. Perhaps you forgot to import it?" % (pb, ext_num))
            continue
        ext_field = pb._extensions_by_number[ext_num]
        field_mapping.append((ext_field, ext_val, pb.Extensions[ext_field]))
    return field_mapping
def _convert_scalar_value(field, input_value, type_callable_map):
    """Convert one scalar input value for assignment to ``field``.

    Applies the ``type_callable_map`` converter for the field's type, if any,
    and then maps enum label strings to their numbers.
    """
    if field.type in type_callable_map:
        input_value = type_callable_map[field.type](input_value)
    if field.type == FieldDescriptor.TYPE_ENUM and isinstance(input_value, six.string_types):
        input_value = _string_to_enum(field, input_value)
    return input_value
def _fill_map_field(field, input_value, pb_map, type_callable_map, strict):
    """Copy the dict ``input_value`` into the protobuf map container ``pb_map``."""
    value_field = field.message_type.fields_by_name['value']
    if value_field.type != FieldDescriptor.TYPE_MESSAGE:
        # Scalar-valued map: the container accepts plain values directly.
        pb_map.update(input_value)
        return
    # Protobuf does not allow assigning to message-valued map entries; indexing
    # yields the (possibly new) entry message, which is then filled in place.
    for key, item in input_value.items():
        if isinstance(item, Message):
            pb_map[key].CopyFrom(item)
        else:
            _dict_to_protobuf(pb_map[key], item, type_callable_map, strict)
def _dict_to_protobuf(pb, value, type_callable_map, strict):
    """Recursively copy the dictionary ``value`` into the message ``pb``.

    :param pb: the message instance to populate (mutated in place).
    :param dict value: field name -> value, as accepted by dict_to_protobuf.
    :param dict type_callable_map: converters applied to scalar values,
        including elements of repeated scalar fields.
    :param bool strict: forwarded to _get_field_mapping.
    :returns: ``pb``.
    """
    fields = _get_field_mapping(pb, value, strict)
    for field, input_value, pb_value in fields:
        if field.label == FieldDescriptor.LABEL_REPEATED:
            if field.message_type and field.message_type.has_options and \
                    field.message_type.GetOptions().map_entry:
                _fill_map_field(field, input_value, pb_value, type_callable_map, strict)
                continue
            for item in input_value:
                if field.type == FieldDescriptor.TYPE_MESSAGE:
                    _dict_to_protobuf(pb_value.add(), item, type_callable_map, strict)
                else:
                    # Repeated scalars need the same conversion as singular
                    # ones (e.g. base64-decoding of repeated bytes fields).
                    pb_value.append(_convert_scalar_value(field, item, type_callable_map))
            continue
        if field.type == FieldDescriptor.TYPE_MESSAGE:
            _dict_to_protobuf(pb_value, input_value, type_callable_map, strict)
            continue
        # Convert before the extension branch so enum labels also work there.
        input_value = _convert_scalar_value(field, input_value, type_callable_map)
        if field.is_extension:
            pb.Extensions[field] = input_value
            continue
        setattr(pb, field.name, input_value)
    return pb
def _string_to_enum(field, input_value):
enum_dict = field.enum_type.values_by_name
try:
input_value = enum_dict[input_value].number
except KeyError:
raise KeyError("`%s` is not a valid value for field `%s`" % (input_value, field.name))
return input_value