Skip to content

Commit 26fd824

Browse files
authored
Update yacs.py
1 parent 10a4b6a commit 26fd824

1 file changed

Lines changed: 158 additions & 56 deletions

File tree

utils/yacs.py

Lines changed: 158 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
# Author: Qiu Jueqin (qiujueqin@gmail.com)
55

66

7+
import os
78
import os.path as op
8-
import copy
99
import argparse
1010
import pathlib
11+
from copy import deepcopy
1112
from collections import OrderedDict, defaultdict
1213
from contextlib import contextmanager
1314

@@ -55,7 +56,7 @@
5556

5657
class Config(OrderedDict):
5758

58-
def __init__(self, init=None):
59+
def __init__(self, init=None, **kwargs):
5960
"""
6061
:param init: dict | yaml filepath | argparse.Namespace
6162
"""
@@ -69,7 +70,7 @@ def __init__(self, init=None):
6970
elif isinstance(init, str):
7071
self.from_yaml(init)
7172
elif isinstance(init, argparse.Namespace):
72-
self.from_namespace(init)
73+
self.from_namespace(init, **kwargs)
7374
else:
7475
raise TypeError(
7576
f'Config could only be instantiated from a dict, a yaml '
@@ -106,10 +107,17 @@ def unfreeze(self):
106107
def _set_immutable(self, is_immutable):
107108
""" Recursively set immutability. """
108109

109-
self.__dict__['__immutable__'] = is_immutable
110-
for v in self.values():
111-
if isinstance(v, Config):
112-
v._set_immutable(is_immutable)
110+
def _recursively_set_immutable(obj):
111+
if isinstance(obj, dict):
112+
if isinstance(obj, Config):
113+
obj.__dict__['__immutable__'] = is_immutable
114+
for v in obj.values():
115+
_recursively_set_immutable(v)
116+
elif isinstance(obj, (list, tuple)):
117+
for item in obj:
118+
_recursively_set_immutable(item)
119+
120+
_recursively_set_immutable(self)
113121

114122
# ---------------- Set & Get ----------------
115123

@@ -163,7 +171,7 @@ def from_yaml(self, yaml_path):
163171
super().__init__(Config._from_dict(dic))
164172
self.freeze()
165173

166-
def from_namespace(self, namespace):
174+
def from_namespace(self, parsed_args, unknown_args=None):
167175
"""
168176
Instantiation from an argparse.Namespace object.
169177
@@ -181,40 +189,77 @@ def from_namespace(self, namespace):
181189
182190
Given the returned argparse.Namespace object 'args', from_namespace()
183191
will create a Config object as if it was instantiated from a nested
184-
dict d = {'foo': {'bar': 42}}
192+
dict d = {'foo': {'bar': 42}}.
193+
194+
Optionally, the extra argument `unknown_args` also accepts unknown
195+
arguments by parser.parse_known_args(), but note that the arguments
196+
on the command line must start with '--'.
197+
198+
For example, creating an argparse.ArgumentParser with '--foo'
199+
argument:
200+
201+
>>> parser = argparse.ArgumentParser()
202+
>>> parser.add_argument('--foo', type=int, default=0)
203+
204+
but in command line user also inputs other arguments:
205+
206+
```
207+
python main.py --foo 42 --bar "['Alice', 'Bob']"
208+
```
209+
210+
Given the returned argparse.Namespace object `parsed` and unknown
211+
args list `unknown` by calling
212+
213+
>>> parsed, unknown = parser.parse_known_args()
214+
215+
`from_namespace(parsed, unknown_args=unknown)` will create a Config
216+
object as if it was instantiated from a dict
217+
d = {'foo': 42, 'bar': ['Alice', 'Bob']}.
185218
"""
186219

187-
if not isinstance(namespace, argparse.Namespace):
220+
if not isinstance(parsed_args, argparse.Namespace):
188221
raise TypeError(
189222
f'expected an argparse.Namespace object, but given a '
190-
f'{type(namespace)} '
223+
f'{type(parsed_args)} '
224+
)
225+
226+
nested_dict = self._separator_dict_to_nested_dict(vars(parsed_args))
227+
228+
if unknown_args:
229+
nested_dict.update(
230+
self._separator_dict_to_nested_dict(self._unknown_args_to_dict(unknown_args))
191231
)
192232

193-
nested_dict = self._separator_dict_to_nested_dict(vars(namespace))
194233
super().__init__(Config._from_dict(nested_dict))
195234
self.freeze()
196235

197-
def merge(self, other, allow_new_attr=False, keep_existed_attr=True):
236+
def merge(self, other,
237+
exclusive=True,
238+
max_exclusive_depth=float('Inf'),
239+
keep_existed_attr=True):
198240
"""
199241
Recursively merge from other object
200242
201243
:param other: Config object | dict | yaml filepath |
202244
argparse.Namespace object
203-
:param allow_new_attr: whether allow to add new attributes
245+
:param exclusive: if set to True, merging with new fields is forbidden
204246
205247
Example:
206248
207249
>>> cfg = Config({'optimizer': 'adam'})
208-
>>> cfg.merge({'lr': 0.001}, allow_new_attr=True)
250+
>>> cfg.merge({'lr': 0.001}, exclusive=False)
209251
>>> cfg.print()
210252
211253
optimizer: adam
212254
lr: 0.001
213255
214-
>>> cfg.merge({'weight_decay': 1E-7}, allow_new_attr=False)
256+
>>> cfg.merge({'weight_decay': 1E-7}, exclusive=True)
215257
216258
AttributeError: attempted to add a new attribute: weight_decay
217259
260+
:param max_exclusive_depth: max depth to prevent from merging new
261+
attributes, only valid when exclusive=True. Set to 0 is equal to
262+
exclusive=False
218263
:param keep_existed_attr: whether keep those attributes that are not
219264
in 'other'. You may wish to trigger this if requires to completely
220265
replace a child Config object. See example/examples.py: Example 5
@@ -225,15 +270,15 @@ def merge(self, other, allow_new_attr=False, keep_existed_attr=True):
225270
>>> cfg1 = Config({'foo': {'Alice': 0, 'Bob': 1}})
226271
>>> cfg2 = cfg1.copy()
227272
>>> another = {'foo': {'Carol': 42}}
228-
>>> cfg1.merge(another, allow_new_attr=True)
273+
>>> cfg1.merge(another, exclusive=False)
229274
>>> cfg1.print()
230275
231276
foo:
232277
Alice: 0
233278
Bob: 1
234279
Carol: 42
235280
236-
>>> cfg2.merge(another, allow_new_attr=True, keep_existed_attr=False)
281+
>>> cfg2.merge(another, exclusive=False, keep_existed_attr=False)
237282
>>> cfg2.print()
238283
239284
foo:
@@ -249,25 +294,24 @@ def merge(self, other, allow_new_attr=False, keep_existed_attr=True):
249294
f'attempted to merge from an unsupported {type(other)} object'
250295
)
251296

252-
def _merge(source_cfg, other_cfg, add_new, keep_existed):
297+
def _merge(source_cfg, other_cfg, excl, keep_existed, _cur_depth=1):
253298
""" Recursively merge the new Config object into the source one """
254-
255-
with source_cfg.unfreeze():
299+
with source_cfg.unfreeze(), other_cfg.unfreeze():
256300
for k, v in other_cfg.items():
257-
if k not in source_cfg and not add_new:
301+
if k not in source_cfg and excl and _cur_depth <= max_exclusive_depth:
258302
raise AttributeError(
259-
f'attempted to add an attribute {k} but it is not '
260-
f'found in the source Config. Set `allow_new_attr` '
261-
f'to True if requires to add new attributes'
303+
f'attempted to merge an attribute `{k}` that is not '
304+
f'found in the source Config. Set `exclusive` to False '
305+
f'if requires to add new attributes'
262306
)
263307

264308
if isinstance(v, Config):
265-
if isinstance(source_cfg.get(k, None), Config):
266-
_merge(source_cfg[k], v, add_new, keep_existed)
309+
if isinstance(source_cfg.get(k), Config):
310+
_merge(source_cfg[k], v, excl, keep_existed, _cur_depth=_cur_depth + 1)
267311
else:
268312
source_cfg[k] = v
269313
else:
270-
source_cfg[k] = copy.deepcopy(v)
314+
source_cfg[k] = deepcopy(v)
271315

272316
if not keep_existed:
273317
source_keys = list(source_cfg.keys())
@@ -276,7 +320,7 @@ def _merge(source_cfg, other_cfg, add_new, keep_existed):
276320
not isinstance(source_cfg[k], Config):
277321
source_cfg.remove(k)
278322

279-
_merge(self, other, allow_new_attr, keep_existed_attr)
323+
_merge(self, other, exclusive, keep_existed_attr)
280324

281325
# ---------------- Output ----------------
282326

@@ -286,16 +330,20 @@ def to_dict(self, alphabetical=False):
286330
An inverse method to self.from_dict()
287331
"""
288332

289-
def _to_dict(config):
290-
if alphabetical:
291-
config = sorted(config.items(), key=lambda x: x[0])
292-
dic = dict(config)
293-
for k, v in dic.items():
294-
if isinstance(v, Config):
295-
dic[k] = _to_dict(v)
296-
return dic
333+
def _recursively_to_dict(obj):
334+
if isinstance(obj, Config):
335+
if alphabetical:
336+
obj = sorted(obj.items(), key=lambda x: x[0])
337+
dic = dict(obj)
338+
for k, v in dic.items():
339+
dic[k] = _recursively_to_dict(v)
340+
return dic
341+
elif isinstance(obj, (list, tuple)):
342+
return tuple(_recursively_to_dict(item) for item in obj)
343+
else:
344+
return obj
297345

298-
return _to_dict(self)
346+
return _recursively_to_dict(self)
299347

300348
def to_parser(self):
301349
"""
@@ -329,38 +377,68 @@ def to_parser(self):
329377

330378
return parser
331379

332-
def dump(self, save_path):
380+
def dump(self, save_path, ignored_keys=()):
    """
    Dump a Config object into a yaml file.

    :param save_path: destination path (str or os.PathLike); must end with
        '.yaml'. Missing parent directories are created.
    :param ignored_keys: top-level keys to leave out of the dumped file.
        Keys of nested Config objects are not filtered.
    :raises TypeError: if `save_path` does not end with '.yaml'
    """
    save_path = os.fspath(save_path)
    if not save_path.endswith('.yaml'):
        raise TypeError('only yaml file is supported by dump() method')

    serializable_types = (bool, str, int, float, list, tuple, dict, set, type(None))

    def _serialize(obj):
        # Anything that is not a plain builtin is replaced by a readable
        # "<str(obj)> <class 'Name'>" placeholder so the dump never fails
        # on arbitrary attribute values.
        if not isinstance(obj, serializable_types):
            return '{} <class \'{}\'>'.format(str(obj), obj.__class__.__name__)
        elif isinstance(obj, dict):
            return {k: _serialize(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            # Sequences are written as plain lists.
            return [_serialize(item) for item in obj]
        else:
            return obj

    serializable_dic = _serialize({
        k: v for k, v in self.to_dict(alphabetical=True).items() if k not in ignored_keys
    })

    # op.dirname() returns '' for a bare filename such as 'cfg.yaml', and
    # os.makedirs('') raises FileNotFoundError, so only create the parent
    # directory when there is one.
    save_dir = op.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(save_path, 'w') as fp:
        yaml.dump(serializable_dic, fp)
336403

337404
def copy(self):
338405
""" Create a deep copy of the Config object """
339-
return Config(copy.deepcopy(self.to_dict()))
406+
return Config(deepcopy(self.to_dict()))
340407

341408
# ---------------- Misc ----------------
342409

410+
def __copy__(self):
411+
return self.copy()
412+
413+
def __deepcopy__(self, memo):
414+
return self.copy()
415+
343416
def __repr__(self):
344417
return self.to_dict().__repr__()
345418

346419
def __str__(self):
347420
return self.to_dict().__repr__()
348421

349-
def _format(self):
422+
def string(self, alphabetical=False, ignored_keys=(), key_width=30, indent=0):
423+
ignored_keys = set(ignored_keys)
350424

351-
def _to_string(dic, indent=0):
425+
def _to_string(dic, idt=indent):
352426
texts = []
353-
for k, v in dic.items():
354-
if not isinstance(v, Config):
355-
texts += ['{:<25}{}'.format(' ' * indent + k + ':', str(v))]
427+
keys = sorted(dic.keys()) if alphabetical else dic.keys()
428+
keys = [k for k in keys if k not in ignored_keys]
429+
for k in keys:
430+
title = ' ' * idt + str(k) + ':'
431+
texts += ['{:<{}}'.format(title, key_width + idt)]
432+
if not isinstance(dic[k], Config):
433+
texts[-1] += str(dic[k])
356434
else:
357-
texts += [k + ':'] + _to_string(v, indent=indent + 2)
435+
texts += _to_string(dic[k], idt=idt + 2)
358436
return texts
359437

360438
return '\n'.join(_to_string(self))
361439

362-
def print(self, streamer=print):
363-
streamer(self._format())
440+
def print(self, streamer=print, alphabetical=False, ignored_keys=None, key_width=40, indent=0):
    """
    Format the Config with self.string() and hand the text to `streamer`.

    :param streamer: callable receiving the formatted text
    :param alphabetical: forwarded to self.string()
    :param ignored_keys: iterable of keys to leave out; None means no key
        is ignored
    :param key_width: forwarded to self.string()
    :param indent: forwarded to self.string()
    :return: whatever `streamer` returns
    """
    # self.string() builds set(ignored_keys), which raises TypeError on None,
    # so the None default has to be normalised to an empty collection here.
    if ignored_keys is None:
        ignored_keys = ()
    return streamer(self.string(alphabetical, ignored_keys, key_width, indent))
364442

365443
def remove(self, key):
366444
""" Remove an attribute by its key. """
@@ -374,15 +452,15 @@ def remove(self, key):
374452

375453
del self[key]
376454

377-
# ---------------- Private ----------------
455+
# ---------------- Helpers ----------------
378456

379457
@classmethod
380458
def _from_dict(cls, dic):
381-
dic = copy.deepcopy(OrderedDict(dic))
459+
dic = deepcopy(OrderedDict(dic))
382460
for k, v in dic.items():
383461
if isinstance(v, dict):
384462
dic[k] = cls(v)
385-
elif isinstance(v, list): # load list as tuple for safety
463+
elif isinstance(v, (list, tuple)): # load list as tuple for safety
386464
dic[k] = tuple(cls(x) if isinstance(x, dict) else x for x in v)
387465

388466
return dic
@@ -419,7 +497,6 @@ def _separator_dict_to_nested_dict(separator_dict, separator='.'):
419497
:param separator_dict: a non-nested dict
420498
:return: a nested dict
421499
"""
422-
separator_dict = copy.deepcopy(separator_dict)
423500

424501
def _init_nested_dict():
425502
return defaultdict(_init_nested_dict)
@@ -432,7 +509,7 @@ def _default_to_dict(d):
432509

433510
nested_dict = _init_nested_dict()
434511

435-
for k, v in separator_dict.items():
512+
for k, v in deepcopy(separator_dict).items():
436513
tmp_d = nested_dict
437514
keys = k.split(separator)
438515
for sub_key in keys[:-1]:
@@ -473,7 +550,6 @@ def _nested_dict_to_separator_dict(nested_dict, separator='.'):
473550
:param nested_dict: a regular (optionally nested) dict
474551
:return: a non-nested dict whose keys contain separators
475552
"""
476-
nested_dict = copy.deepcopy(nested_dict)
477553

478554
def _create_separator_dict(x, key='', separator_dict={}):
479555
if isinstance(x, dict):
@@ -484,4 +560,30 @@ def _create_separator_dict(x, key='', separator_dict={}):
484560
separator_dict[key] = x
485561
return separator_dict
486562

487-
return _create_separator_dict(nested_dict)
563+
return _create_separator_dict(deepcopy(nested_dict))
564+
565+
@staticmethod
566+
def _unknown_args_to_dict(unknown_args):
    """
    Convert the unknown argument list returned by `parser.parse_known_args()` into a dict.

    Values are parsed with `ast.literal_eval`, so only Python literals
    (numbers, strings, tuples, lists, dicts, sets, booleans, None) are
    accepted; arbitrary expressions from the command line are never executed.

    :param unknown_args: sequence of argument strings, in which the keys must start with
        '--' and the values must be Python literal expressions. All items
        between two keys are joined with a single space to form one value.
        A key that is not followed by any value is skipped.
    :return: a non-nested dict
    :raises SyntaxError: if a value is not a valid Python literal
    """
    # Imported locally so the helper does not depend on the module's import block.
    import ast

    dic = {}

    key, value_lst = None, []
    # The trailing '--' sentinel flushes the last pending key/value pair.
    for item in [*unknown_args, '--']:
        if item.startswith('--'):
            if key and value_lst:
                literal = ' '.join(value_lst)
                try:
                    dic[key] = ast.literal_eval(literal)
                except (ValueError, SyntaxError) as err:
                    raise SyntaxError(
                        'invalid argument: --{} {}'.format(key, literal)
                    ) from err

            # Strip only the leading '--'; a '--' inside the key is kept.
            key, value_lst = item[2:], []
        else:
            value_lst.append(item)

    return dic

0 commit comments

Comments
 (0)