44# Author: Qiu Jueqin (qiujueqin@gmail.com)
55
66
7+ import os
78import os .path as op
8- import copy
99import argparse
1010import pathlib
11+ from copy import deepcopy
1112from collections import OrderedDict , defaultdict
1213from contextlib import contextmanager
1314
5556
5657class Config (OrderedDict ):
5758
58- def __init__ (self , init = None ):
59+ def __init__ (self , init = None , ** kwargs ):
5960 """
6061 :param init: dict | yaml filepath | argparse.Namespace
6162 """
@@ -69,7 +70,7 @@ def __init__(self, init=None):
6970 elif isinstance (init , str ):
7071 self .from_yaml (init )
7172 elif isinstance (init , argparse .Namespace ):
72- self .from_namespace (init )
73+ self .from_namespace (init , ** kwargs )
7374 else :
7475 raise TypeError (
7576 f'Config could only be instantiated from a dict, a yaml '
@@ -106,10 +107,17 @@ def unfreeze(self):
106107 def _set_immutable (self , is_immutable ):
107108 """ Recursively set immutability. """
108109
109- self .__dict__ ['__immutable__' ] = is_immutable
110- for v in self .values ():
111- if isinstance (v , Config ):
112- v ._set_immutable (is_immutable )
110+ def _recursively_set_immutable (obj ):
111+ if isinstance (obj , dict ):
112+ if isinstance (obj , Config ):
113+ obj .__dict__ ['__immutable__' ] = is_immutable
114+ for v in obj .values ():
115+ _recursively_set_immutable (v )
116+ elif isinstance (obj , (list , tuple )):
117+ for item in obj :
118+ _recursively_set_immutable (item )
119+
120+ _recursively_set_immutable (self )
113121
114122 # ---------------- Set & Get ----------------
115123
@@ -163,7 +171,7 @@ def from_yaml(self, yaml_path):
163171 super ().__init__ (Config ._from_dict (dic ))
164172 self .freeze ()
165173
166- def from_namespace (self , namespace ):
174+ def from_namespace (self , parsed_args , unknown_args = None ):
167175 """
168176 Instantiation from an argparse.Namespace object.
169177
@@ -181,40 +189,77 @@ def from_namespace(self, namespace):
181189
182190 Given the returned argparse.Namespace object 'args', from_namespace()
183191 will create a Config object as if it was instantiated from a nested
184- dict d = {'foo': {'bar': 42}}
192+ dict d = {'foo': {'bar': 42}}.
193+
194+ Optionally, the extra argument `unknown_args` also accepts unknown
195+ arguments by parser.parse_known_args(), but note that the arguments
196+ in command line must starts with '--'.
197+
198+ For example, creating an argparse.ArgumentParser with '--foo'
199+ argument:
200+
201+ >>> parser = argparse.ArgumentParser()
202+ >>> parser.add_argument('--foo', type=int, default=0)
203+
204+ but in command line user also inputs other arguments:
205+
206+ ```
207+ python main.py --foo 42 --bar ['Alice', 'Bob']
208+ ```
209+
210+ Given the returned argparse.Namespace object `parsed` and unknown
211+ args list `unknown` by calling
212+
213+ >>> parsed, unknown = parser.parse_known_args()
214+
215+ `from_namespace(parsed, unknown_args=unknown)` will create a Config
216+ object as if it was instantiated from a dict
217+ d = {'foo': 42, 'bar': ['Alice', 'Bob']}.
185218 """
186219
187- if not isinstance (namespace , argparse .Namespace ):
220+ if not isinstance (parsed_args , argparse .Namespace ):
188221 raise TypeError (
189222 f'expected an argparse.Namespace object, but given a '
190- f'{ type (namespace )} '
223+ f'{ type (parsed_args )} '
224+ )
225+
226+ nested_dict = self ._separator_dict_to_nested_dict (vars (parsed_args ))
227+
228+ if unknown_args :
229+ nested_dict .update (
230+ self ._separator_dict_to_nested_dict (self ._unknown_args_to_dict (unknown_args ))
191231 )
192232
193- nested_dict = self ._separator_dict_to_nested_dict (vars (namespace ))
194233 super ().__init__ (Config ._from_dict (nested_dict ))
195234 self .freeze ()
196235
197- def merge (self , other , allow_new_attr = False , keep_existed_attr = True ):
236+ def merge (self , other ,
237+ exclusive = True ,
238+ max_exclusive_depth = float ('Inf' ),
239+ keep_existed_attr = True ):
198240 """
199241 Recursively merge from other object
200242
201243 :param other: Config object | dict | yaml filepath |
202244 argparse.Namespace object
203- :param allow_new_attr: whether allow to add new attributes
245+ :param exclusive: if set to True, merging with new fields is forbidden
204246
205247 Example:
206248
207249 >>> cfg = Config({'optimizer': 'adam'})
208- >>> cfg.merge({'lr': 0.001}, allow_new_attr=True )
250+ >>> cfg.merge({'lr': 0.001}, exclusive=False )
209251 >>> cfg.print()
210252
211253 optimizer: adam
212254 lr: 0.001
213255
214- >>> cfg.merge({'weight_decay': 1E-7}, allow_new_attr=False )
256+ >>> cfg.merge({'weight_decay': 1E-7}, exclusive=True )
215257
216258 AttributeError: attempted to add a new attribute: weight_decay
217259
260+ :param max_exclusive_depth: max depth to prevent from merging new
261+ attributes, only valid when exclusive=True. Set to 0 is equal to
262+ exclusive=False
218263 :param keep_existed_attr: whether keep those attributes that are not
219264 in 'other'. You may wish to trigger this if requires to completely
220265 replace a child Config object. See example/examples.py: Example 5
@@ -225,15 +270,15 @@ def merge(self, other, allow_new_attr=False, keep_existed_attr=True):
225270 >>> cfg1 = Config({'foo': {'Alice': 0, 'Bob': 1}})
226271 >>> cfg2 = cfg1.copy()
227272 >>> another = {'foo': {'Carol': 42}}
228- >>> cfg1.merge(another, allow_new_attr=True )
273+ >>> cfg1.merge(another, exclusive=False )
229274 >>> cfg1.print()
230275
231276 foo:
232277 Alice: 0
233278 Bob: 1
234279 Carol: 42
235280
236- >>> cfg2.merge(another, allow_new_attr=True , keep_existed_attr=False)
281+ >>> cfg2.merge(another, exclusive=False , keep_existed_attr=False)
237282 >>> cfg2.print()
238283
239284 foo:
@@ -249,25 +294,24 @@ def merge(self, other, allow_new_attr=False, keep_existed_attr=True):
249294 f'attempted to merge from an unsupported { type (other )} object'
250295 )
251296
252- def _merge (source_cfg , other_cfg , add_new , keep_existed ):
297+ def _merge (source_cfg , other_cfg , excl , keep_existed , _cur_depth = 1 ):
253298 """ Recursively merge the new Config object into the source one """
254-
255- with source_cfg .unfreeze ():
299+ with source_cfg .unfreeze (), other_cfg .unfreeze ():
256300 for k , v in other_cfg .items ():
257- if k not in source_cfg and not add_new :
301+ if k not in source_cfg and excl and _cur_depth <= max_exclusive_depth :
258302 raise AttributeError (
259- f'attempted to add an attribute { k } but it is not '
260- f'found in the source Config. Set `allow_new_attr` '
261- f'to True if requires to add new attributes'
303+ f'attempted to merge an attribute ` { k } ` that is not '
304+ f'found in the source Config. Set `exclusive` to False '
305+ f'if requires to add new attributes'
262306 )
263307
264308 if isinstance (v , Config ):
265- if isinstance (source_cfg .get (k , None ), Config ):
266- _merge (source_cfg [k ], v , add_new , keep_existed )
309+ if isinstance (source_cfg .get (k ), Config ):
310+ _merge (source_cfg [k ], v , excl , keep_existed , _cur_depth = _cur_depth + 1 )
267311 else :
268312 source_cfg [k ] = v
269313 else :
270- source_cfg [k ] = copy . deepcopy (v )
314+ source_cfg [k ] = deepcopy (v )
271315
272316 if not keep_existed :
273317 source_keys = list (source_cfg .keys ())
@@ -276,7 +320,7 @@ def _merge(source_cfg, other_cfg, add_new, keep_existed):
276320 not isinstance (source_cfg [k ], Config ):
277321 source_cfg .remove (k )
278322
279- _merge (self , other , allow_new_attr , keep_existed_attr )
323+ _merge (self , other , exclusive , keep_existed_attr )
280324
281325 # ---------------- Output ----------------
282326
@@ -286,16 +330,20 @@ def to_dict(self, alphabetical=False):
286330 An inverse method to self.from_dict()
287331 """
288332
289- def _to_dict (config ):
290- if alphabetical :
291- config = sorted (config .items (), key = lambda x : x [0 ])
292- dic = dict (config )
293- for k , v in dic .items ():
294- if isinstance (v , Config ):
295- dic [k ] = _to_dict (v )
296- return dic
333+ def _recursively_to_dict (obj ):
334+ if isinstance (obj , Config ):
335+ if alphabetical :
336+ obj = sorted (obj .items (), key = lambda x : x [0 ])
337+ dic = dict (obj )
338+ for k , v in dic .items ():
339+ dic [k ] = _recursively_to_dict (v )
340+ return dic
341+ elif isinstance (obj , (list , tuple )):
342+ return tuple (_recursively_to_dict (item ) for item in obj )
343+ else :
344+ return obj
297345
298- return _to_dict (self )
346+ return _recursively_to_dict (self )
299347
300348 def to_parser (self ):
301349 """
@@ -329,38 +377,68 @@ def to_parser(self):
329377
330378 return parser
331379
332- def dump (self , save_path ):
380+ def dump (self , save_path , ignored_keys = () ):
333381 """ Dump a Config object into a yaml file """
382+ if not save_path .endswith ('.yaml' ):
383+ raise TypeError ('only yaml file is supported by dump() method' )
384+
385+ def _serialize (obj ):
386+ serializable_types = (bool , str , int , float , list , tuple , dict , set , type (None ))
387+ if not isinstance (obj , serializable_types ):
388+ return '{} <class \' {}\' >' .format (str (obj ), obj .__class__ .__name__ )
389+ elif isinstance (obj , dict ):
390+ return {k : _serialize (v ) for k , v in obj .items ()}
391+ elif isinstance (obj , (list , tuple )):
392+ return [_serialize (item ) for item in obj ]
393+ else :
394+ return obj
395+
396+ serializable_dic = _serialize ({
397+ k : v for k , v in self .to_dict (alphabetical = True ).items () if k not in ignored_keys
398+ })
399+
400+ os .makedirs (op .dirname (save_path ), exist_ok = True )
334401 with open (save_path , 'w' ) as fp :
335- yaml .safe_dump ( self . to_dict ( alphabetical = True ) , fp )
402+ yaml .dump ( serializable_dic , fp )
336403
337404 def copy (self ):
338405 """ Create a deep copy of the Config object """
339- return Config (copy . deepcopy (self .to_dict ()))
406+ return Config (deepcopy (self .to_dict ()))
340407
341408 # ---------------- Misc ----------------
342409
410+ def __copy__ (self ):
411+ return self .copy ()
412+
413+ def __deepcopy__ (self , memo ):
414+ return self .copy ()
415+
343416 def __repr__ (self ):
344417 return self .to_dict ().__repr__ ()
345418
346419 def __str__ (self ):
347420 return self .to_dict ().__repr__ ()
348421
349- def _format (self ):
422+ def string (self , alphabetical = False , ignored_keys = (), key_width = 30 , indent = 0 ):
423+ ignored_keys = set (ignored_keys )
350424
351- def _to_string (dic , indent = 0 ):
425+ def _to_string (dic , idt = indent ):
352426 texts = []
353- for k , v in dic .items ():
354- if not isinstance (v , Config ):
355- texts += ['{:<25}{}' .format (' ' * indent + k + ':' , str (v ))]
427+ keys = sorted (dic .keys ()) if alphabetical else dic .keys ()
428+ keys = [k for k in keys if k not in ignored_keys ]
429+ for k in keys :
430+ title = ' ' * idt + str (k ) + ':'
431+ texts += ['{:<{}}' .format (title , key_width + idt )]
432+ if not isinstance (dic [k ], Config ):
433+ texts [- 1 ] += str (dic [k ])
356434 else :
357- texts += [ k + ':' ] + _to_string ( v , indent = indent + 2 )
435+ texts += _to_string ( dic [ k ], idt = idt + 2 )
358436 return texts
359437
360438 return '\n ' .join (_to_string (self ))
361439
362- def print (self , streamer = print ):
363- streamer (self ._format ( ))
440+ def print (self , streamer = print , alphabetical = False , ignored_keys = None , key_width = 40 , indent = 0 ):
441+ return streamer (self .string ( alphabetical , ignored_keys , key_width , indent ))
364442
365443 def remove (self , key ):
366444 """ Remove an attribute by its key. """
@@ -374,15 +452,15 @@ def remove(self, key):
374452
375453 del self [key ]
376454
377- # ---------------- Private ----------------
455+ # ---------------- Helpers ----------------
378456
379457 @classmethod
380458 def _from_dict (cls , dic ):
381- dic = copy . deepcopy (OrderedDict (dic ))
459+ dic = deepcopy (OrderedDict (dic ))
382460 for k , v in dic .items ():
383461 if isinstance (v , dict ):
384462 dic [k ] = cls (v )
385- elif isinstance (v , list ): # load list as tuple for safety
463+ elif isinstance (v , ( list , tuple ) ): # load list as tuple for safety
386464 dic [k ] = tuple (cls (x ) if isinstance (x , dict ) else x for x in v )
387465
388466 return dic
@@ -419,7 +497,6 @@ def _separator_dict_to_nested_dict(separator_dict, separator='.'):
419497 :param separator_dict: a non-nested dict
420498 :return: a nested dict
421499 """
422- separator_dict = copy .deepcopy (separator_dict )
423500
424501 def _init_nested_dict ():
425502 return defaultdict (_init_nested_dict )
@@ -432,7 +509,7 @@ def _default_to_dict(d):
432509
433510 nested_dict = _init_nested_dict ()
434511
435- for k , v in separator_dict .items ():
512+ for k , v in deepcopy ( separator_dict ) .items ():
436513 tmp_d = nested_dict
437514 keys = k .split (separator )
438515 for sub_key in keys [:- 1 ]:
@@ -473,7 +550,6 @@ def _nested_dict_to_separator_dict(nested_dict, separator='.'):
473550 :param nested_dict: a regular (optionally nested) dict
474551 :return: a non-nested dict whose keys contain separators
475552 """
476- nested_dict = copy .deepcopy (nested_dict )
477553
478554 def _create_separator_dict (x , key = '' , separator_dict = {}):
479555 if isinstance (x , dict ):
@@ -484,4 +560,30 @@ def _create_separator_dict(x, key='', separator_dict={}):
484560 separator_dict [key ] = x
485561 return separator_dict
486562
487- return _create_separator_dict (nested_dict )
563+ return _create_separator_dict (deepcopy (nested_dict ))
564+
565+ @staticmethod
566+ def _unknown_args_to_dict (unknown_args ):
567+ """
568+ Convert unknown argument list returned by `parser.parse_known_args()` into a dict
569+ :param unknown_args: list of arguments, in which the keys must starts with '--' and
570+ the values could be any Python literal expression
571+ :return: a non-nested dict
572+ """
573+ dic = {}
574+
575+ key , value_lst = None , []
576+ for item in unknown_args + ['--' ]:
577+ if item .startswith ('--' ):
578+ if key and value_lst :
579+ literal = ' ' .join (value_lst )
580+ try :
581+ dic [key ] = eval (literal )
582+ except SyntaxError as e :
583+ raise SyntaxError ('invalid argument: --{} {}' .format (key , literal ))
584+
585+ key , value_lst = item .replace ('--' , '' ), []
586+ else :
587+ value_lst .append (item )
588+
589+ return dic
0 commit comments