@@ -578,10 +578,10 @@ class ModelParallel(Distribution):
578578 # will be split across 4 devices. Any other variable that doesn't
579579 # match any key in the layout map will be fully replicated.
580580 layout_map = LayoutMap(device_mesh)
581- layout_map['dense.*kernel'] = (None, 'model')
582- layout_map['dense.*bias'] = ('model',)
583- layout_map['conv2d.*kernel'] = (None, None, None, 'model')
584- layout_map['conv2d.*bias'] = ('model',)
581+ layout_map['.* dense.*kernel'] = (None, 'model')
582+ layout_map['.* dense.*bias'] = ('model',)
583+ layout_map['.* conv2d.*kernel'] = (None, None, None, 'model')
584+ layout_map['.* conv2d.*bias'] = ('model',)
585585
586586 distribution = ModelParallel(
587587 layout_map=layout_map,
@@ -777,10 +777,10 @@ class LayoutMap(collections.abc.MutableMapping):
777777
778778 ```python
779779 layout_map = LayoutMap(device_mesh)
780- layout_map['dense.*kernel'] = (None, 'model')
781- layout_map['dense.*bias'] = ('model',)
782- layout_map['conv2d.*kernel'] = (None, None, None, 'model')
783- layout_map['conv2d.*bias'] = ('model',)
780+ layout_map['.* dense.*kernel'] = (None, 'model')
781+ layout_map['.* dense.*bias'] = ('model',)
782+ layout_map['.* conv2d.*kernel'] = (None, None, None, 'model')
783+ layout_map['.* conv2d.*bias'] = ('model',)
784784
785785 layout_1 = layout_map['dense_1.kernel'] # layout_1 == layout_2d
786786 layout_2 = layout_map['dense_1.bias'] # layout_2 == layout_1d
@@ -817,10 +817,11 @@ def __getitem__(self, key):
817817 if key in self ._layout_map :
818818 return self ._layout_map [key ]
819819
820- matching_keys = []
821- for k in self ._layout_map :
822- if re .search (k , key ):
823- matching_keys .append (k )
820+ matching_keys = [
821+ pattern
822+ for pattern in self ._layout_map
823+ if re .fullmatch (pattern , key )
824+ ]
824825 if len (matching_keys ) > 1 :
825826 raise ValueError (
826827 f"Path '{ key } ' matches multiple layout "
0 commit comments