Skip to content

Commit 90e2e73

Browse files
Refactors LayoutMap to use strict regex matching (#22164)
1 parent 2465b66 commit 90e2e73

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

keras/src/distribution/distribution_lib.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -578,10 +578,10 @@ class ModelParallel(Distribution):
578578
# will be split across 4 devices. Any other variable that doesn't
579579
# match any key in the layout map will be fully replicated.
580580
layout_map = LayoutMap(device_mesh)
581-
layout_map['dense.*kernel'] = (None, 'model')
582-
layout_map['dense.*bias'] = ('model',)
583-
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
584-
layout_map['conv2d.*bias'] = ('model',)
581+
layout_map['.*dense.*kernel'] = (None, 'model')
582+
layout_map['.*dense.*bias'] = ('model',)
583+
layout_map['.*conv2d.*kernel'] = (None, None, None, 'model')
584+
layout_map['.*conv2d.*bias'] = ('model',)
585585
586586
distribution = ModelParallel(
587587
layout_map=layout_map,
@@ -777,10 +777,10 @@ class LayoutMap(collections.abc.MutableMapping):
777777
778778
```python
779779
layout_map = LayoutMap(device_mesh)
780-
layout_map['dense.*kernel'] = (None, 'model')
781-
layout_map['dense.*bias'] = ('model',)
782-
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
783-
layout_map['conv2d.*bias'] = ('model',)
780+
layout_map['.*dense.*kernel'] = (None, 'model')
781+
layout_map['.*dense.*bias'] = ('model',)
782+
layout_map['.*conv2d.*kernel'] = (None, None, None, 'model')
783+
layout_map['.*conv2d.*bias'] = ('model',)
784784
785785
layout_1 = layout_map['dense_1.kernel'] # layout_1 == layout_2d
786786
layout_2 = layout_map['dense_1.bias'] # layout_2 == layout_1d
@@ -817,10 +817,11 @@ def __getitem__(self, key):
817817
if key in self._layout_map:
818818
return self._layout_map[key]
819819

820-
matching_keys = []
821-
for k in self._layout_map:
822-
if re.search(k, key):
823-
matching_keys.append(k)
820+
matching_keys = [
821+
pattern
822+
for pattern in self._layout_map
823+
if re.fullmatch(pattern, key)
824+
]
824825
if len(matching_keys) > 1:
825826
raise ValueError(
826827
f"Path '{key}' matches multiple layout "

keras/src/distribution/distribution_lib_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,10 @@ def test_get(self):
412412
layout_map["dense/kernel"] = self.sharded_2d
413413
layout_map["dense/bias"] = self.sharded_1d
414414

415-
layout_map["dense.*kernel"] = self.replicated_2d
416-
layout_map["dense.*bias"] = self.replicated_1d
415+
layout_map[".*dense.*kernel"] = self.replicated_2d
416+
layout_map[".*dense.*bias"] = self.replicated_1d
417417

418-
layout_map["bias"] = self.sharded_1d
418+
layout_map[".*bias"] = self.sharded_1d
419419

420420
self.assertEqual(layout_map["dense/kernel"], self.sharded_2d)
421421
self.assertEqual(layout_map["dense/bias"], self.sharded_1d)

0 commit comments

Comments
 (0)