Skip to content

Commit 6d9499a

Browse files
authored
Merge pull request #50 from alxndrTL/mup-2
muP for Mamba and Mamba-2
2 parents 6cd7a11 + 6b0ec38 commit 6d9499a

File tree

7 files changed

+870
-23
lines changed

7 files changed

+870
-23
lines changed

mambapy/lm.py

Lines changed: 99 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,67 @@ def __init__(self, model_config: Union[MambaConfig, Mamba2Config], vocab_size: i
4141
self.norm_f = RMSNorm(self.config.d_model, self.config.rms_norm_eps, self.config.mup)
4242

4343
self.lm_head = nn.Linear(self.config.d_model, vocab_size, bias=False)
44-
self.embedding.weight = self.lm_head.weight
44+
# self.embedding.weight = self.lm_head.weight # weight-tying disabled
45+
46+
# muP custom initialization
47+
if self.config.mup and isinstance(self.config, MambaConfig):
48+
for pn, p in self.named_parameters():
49+
if any(pn.endswith(w) for w in ['mixer.in_proj.weight', 'mixer.x_proj.weight', 'mixer.dt_proj.weight', 'mixer.out_proj.weight']): # "hidden weights"
50+
std = self.config.base_std
51+
52+
if 'mixer.out_proj.weight' in pn:
53+
std = std / math.sqrt(2 * self.config.n_layers) # scale down std of layers which project onto the residual stream (not muP related)
54+
55+
if 'mixer.dt_proj.weight' in pn:
56+
std = self.config.dt_rank**-0.5 * self.config.dt_scale
57+
torch.nn.init.normal_(p, mean=0.0, std=std / math.sqrt(self.config.mup_width_mult))
58+
elif 'mixer.conv1d.weight' in pn:
59+
torch.nn.init.zeros_(p)
60+
elif pn == "embedding.weight":
61+
torch.nn.init.normal_(p, mean=0.0, std=self.config.base_std)
62+
elif pn == "lm_head.weight":
63+
torch.nn.init.zeros_(p)
64+
elif any(pn.endswith(w) for w in ['mixer.A_log', 'mixer.D']):
65+
# keep Mamba default init for these params
66+
pass
67+
else:
68+
# here, we only have biases
69+
assert p.dim() == 1, f"a 2d param ({pn}) has not been filtered out for init. please check."
70+
71+
if ("in_proj.bias" in pn) or ("out_proj.bias" in pn):
72+
torch.nn.init.zeros_(p)
73+
74+
elif self.config.mup and isinstance(self.config, Mamba2Config):
75+
for pn, p in self.named_parameters():
76+
77+
if any(pn.endswith(w) for w in ['mixer.in_proj.weight', 'mixer.out_proj.weight']): # "hidden weights"
78+
std = self.config.base_std
79+
80+
if 'mixer.out_proj.weight' in pn:
81+
std = std / math.sqrt(2 * self.config.n_layers) # scale down std of layers which project onto the residual stream (not muP related)
82+
83+
torch.nn.init.normal_(p, mean=0.0, std=std / math.sqrt(self.config.mup_width_mult))
84+
elif 'mixer.conv1d.weight' in pn:
85+
torch.nn.init.zeros_(p)
86+
elif pn == "embedding.weight":
87+
torch.nn.init.normal_(p, mean=0.0, std=self.config.base_std)
88+
elif pn == "lm_head.weight":
89+
torch.nn.init.zeros_(p)
90+
elif any(pn.endswith(w) for w in ['mixer.A_log', 'mixer.D', 'mixer.dt_bias']):
91+
# keep Mamba default init for these params
92+
pass
93+
else:
94+
# here, we only have biases
95+
assert p.dim() == 1, f"a 2d param ({pn}) has not been filtered out for init. please check."
4596

46-
self.apply(self._init_weights)
47-
for pn, p in self.named_parameters():
48-
if pn.endswith('fc_3.weight') or pn.endswith('c_proj.weight'):
49-
torch.nn.init.normal_(p, mean=0.0, std=self.config.base_std/math.sqrt(2 * self.config.n_layers))
97+
if ("in_proj.bias" in pn) or ("out_proj.bias" in pn):
98+
torch.nn.init.zeros_(p)
99+
100+
else:
101+
self.apply(self._init_weights)
102+
for pn, p in self.named_parameters():
103+
if pn.endswith('mixer.out_proj.weight'):
104+
torch.nn.init.normal_(p, mean=0.0, std=self.config.base_std/math.sqrt(2 * self.config.n_layers))
50105

51106
def forward(self, tokens):
52107
# tokens : (B, L)
@@ -142,13 +197,46 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
142197

143198
# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
144199
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
145-
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
146-
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
200+
if self.config.mup and isinstance(self.config, MambaConfig):
201+
mup_params_keys = set([pn for pn in param_dict.keys() if any(pn.endswith(w) for w in ['mixer.in_proj.weight', 'mixer.x_proj.weight', 'mixer.dt_proj.weight', 'mixer.out_proj.weight'])])
202+
203+
dim2_params_keys = set([pn for pn in param_dict.keys() if param_dict[pn].dim() >= 2])
204+
dim2_params_keys = dim2_params_keys.difference(mup_params_keys)
205+
206+
mup_parameters = [p for n, p in param_dict.items() if n in mup_params_keys]
207+
decay_params = [p for n, p in param_dict.items() if n in dim2_params_keys]
208+
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] # biases and D
209+
210+
optim_groups = [
211+
{'params': mup_parameters, 'weight_decay': weight_decay * self.config.mup_width_mult, 'lr': learning_rate / self.config.mup_width_mult},
212+
{'params': decay_params, 'weight_decay': weight_decay, 'lr': learning_rate},
213+
{'params': nodecay_params, 'weight_decay': 0.0, 'lr': learning_rate}
214+
]
215+
216+
elif self.config.mup and isinstance(self.config, Mamba2Config):
217+
mup_params_keys = set([pn for pn in param_dict.keys() if any(pn.endswith(w) for w in ['mixer.in_proj.weight', 'mixer.out_proj.weight'])])
218+
219+
dim2_params_keys = set([pn for pn in param_dict.keys() if param_dict[pn].dim() >= 2])
220+
dim2_params_keys = dim2_params_keys.difference(mup_params_keys)
221+
222+
mup_parameters = [p for n, p in param_dict.items() if n in mup_params_keys]
223+
decay_params = [p for n, p in param_dict.items() if n in dim2_params_keys]
224+
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] # biases and D and A
225+
226+
optim_groups = [
227+
{'params': mup_parameters, 'weight_decay': weight_decay * self.config.mup_width_mult, 'lr': learning_rate / self.config.mup_width_mult},
228+
{'params': decay_params, 'weight_decay': weight_decay, 'lr': learning_rate},
229+
{'params': nodecay_params, 'weight_decay': 0.0, 'lr': learning_rate}
230+
]
231+
232+
else:
233+
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
234+
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
147235

148-
optim_groups = [
149-
{'params': decay_params, 'weight_decay': weight_decay, 'lr': learning_rate},
150-
{'params': nodecay_params, 'weight_decay': 0.0, 'lr': learning_rate}
151-
]
236+
optim_groups = [
237+
{'params': decay_params, 'weight_decay': weight_decay, 'lr': learning_rate},
238+
{'params': nodecay_params, 'weight_decay': 0.0, 'lr': learning_rate}
239+
]
152240

153241
# Create AdamW optimizer and use the fused version if it is available
154242
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters

mambapy/mamba.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ class MambaConfig:
5050
conv_bias: bool = True
5151
inner_layernorms: bool = False # apply layernorms to internal activations
5252

53+
mup: bool = False
54+
mup_base_width: float = 128 # width=d_model
55+
5356
pscan: bool = True # use parallel scan mode or sequential mode when training
5457
use_cuda: bool = False # use official CUDA implementation when training (not compatible with (b)float16)
5558

@@ -59,6 +62,10 @@ def __post_init__(self):
5962
if self.dt_rank == 'auto':
6063
self.dt_rank = math.ceil(self.d_model / 16)
6164

65+
# muP
66+
if self.mup:
67+
self.mup_width_mult = self.d_model / self.mup_base_width
68+
6269
class Mamba(nn.Module):
6370
def __init__(self, config: MambaConfig):
6471
super().__init__()
@@ -94,7 +101,7 @@ def __init__(self, config: MambaConfig):
94101
super().__init__()
95102

96103
self.mixer = MambaBlock(config)
97-
self.norm = RMSNorm(config.d_model, config.rms_norm_eps)
104+
self.norm = RMSNorm(config.d_model, config.rms_norm_eps, config.mup)
98105

99106
def forward(self, x):
100107
# x : (B, L, D)
@@ -170,9 +177,9 @@ def __init__(self, config: MambaConfig):
170177

171178
# used in jamba
172179
if self.config.inner_layernorms:
173-
self.dt_layernorm = RMSNorm(self.config.dt_rank, config.rms_norm_eps)
174-
self.B_layernorm = RMSNorm(self.config.d_state, config.rms_norm_eps)
175-
self.C_layernorm = RMSNorm(self.config.d_state, config.rms_norm_eps)
180+
self.dt_layernorm = RMSNorm(self.config.dt_rank, config.rms_norm_eps, config.mup)
181+
self.B_layernorm = RMSNorm(self.config.d_state, config.rms_norm_eps, config.mup)
182+
self.C_layernorm = RMSNorm(self.config.d_state, config.rms_norm_eps, config.mup)
176183
else:
177184
self.dt_layernorm = None
178185
self.B_layernorm = None
@@ -407,13 +414,21 @@ def ssm_step(self, x, h):
407414
return y, h
408415

409416
class RMSNorm(nn.Module):
410-
def __init__(self, d_model: int, eps: float = 1e-5):
417+
def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
411418
super().__init__()
412419

420+
self.use_mup = use_mup
413421
self.eps = eps
414-
self.weight = nn.Parameter(torch.ones(d_model))
422+
423+
# https://arxiv.org/abs/2404.05728, RMSNorm gains prevent muTransfer (section 4.2.3)
424+
if not use_mup:
425+
self.weight = nn.Parameter(torch.ones(d_model))
415426

416427
def forward(self, x):
417428
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
418429

419-
return output * self.weight
430+
if not self.use_mup:
431+
return output * self.weight
432+
else:
433+
return output
434+

mambapy/mamba2.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44
55
adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2_simple.py
6-
It justs implements a config similar to what's being done in mamba.py.
6+
It just implements a config similar to what's being done in mamba.py, as well as supports muP.
77
88
"""
99

@@ -56,6 +56,9 @@ class Mamba2Config:
5656
bias: bool = False
5757
conv_bias: bool = True
5858

59+
mup: bool = False
60+
mup_base_width: float = 128 # width=d_model
61+
5962
chunk_size: int = 256
6063
use_mem_eff_path: bool = True
6164
dtype=None
@@ -68,6 +71,10 @@ def __post_init__(self):
6871

6972
assert (self.d_inner / self.d_head) % 8 == 0, "requirement of causal_conv1d"
7073

74+
# muP
75+
if self.mup:
76+
self.mup_width_mult = self.d_model / self.mup_base_width
77+
7178
class Mamba2(nn.Module):
7279
def __init__(self, config: Mamba2Config):
7380
super().__init__()
@@ -103,7 +110,7 @@ def __init__(self, config: Mamba2Config):
103110
super().__init__()
104111

105112
self.mixer = Mamba2Block(config)
106-
self.norm = RMSNorm(config.d_model, config.rms_norm_eps)
113+
self.norm = RMSNorm(config.d_model, config.rms_norm_eps, config.mup)
107114

108115
def forward(self, x):
109116
# x : (B, L, D)
@@ -152,6 +159,7 @@ def __init__(self, config: Mamba2Config):
152159
nn.init.uniform_(self.conv1d.weight, -self.config.conv_init, self.config.conv_init)
153160
# self.conv1d.weight._no_weight_decay = True
154161

162+
# todo : mup init + lr
155163
if self.config.learnable_init_states:
156164
self.init_states = nn.Parameter(torch.zeros(self.config.n_heads, self.config.d_head, self.config.d_state, **factory_kwargs))
157165
self.init_states._no_weight_decay = True
@@ -267,12 +275,20 @@ def forward(self, u, seq_idx=None):
267275

268276
# taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
269277
class RMSNorm(nn.Module):
270-
def __init__(self, d_model: int, eps: float = 1e-5):
278+
def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
271279
super().__init__()
272280

281+
self.use_mup = use_mup
273282
self.eps = eps
274-
self.weight = nn.Parameter(torch.ones(d_model))
283+
284+
# https://arxiv.org/abs/2404.05728, RMSNorm gains prevent muTransfer (section 4.2.3)
285+
if not use_mup:
286+
self.weight = nn.Parameter(torch.ones(d_model))
275287

276288
def forward(self, x):
277289
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
278-
return output * self.weight
290+
291+
if not self.use_mup:
292+
return output * self.weight
293+
else:
294+
return output

0 commit comments

Comments
 (0)