@@ -41,12 +41,67 @@ def __init__(self, model_config: Union[MambaConfig, Mamba2Config], vocab_size: i
4141 self .norm_f = RMSNorm (self .config .d_model , self .config .rms_norm_eps , self .config .mup )
4242
4343 self .lm_head = nn .Linear (self .config .d_model , vocab_size , bias = False )
44- self .embedding .weight = self .lm_head .weight
44+ # self.embedding.weight = self.lm_head.weight # weight-tying disabled
45+
46+ # muP custom initialization
47+ if self .config .mup and isinstance (self .config , MambaConfig ):
48+ for pn , p in self .named_parameters ():
49+ if any (pn .endswith (w ) for w in ['mixer.in_proj.weight' , 'mixer.x_proj.weight' , 'mixer.dt_proj.weight' , 'mixer.out_proj.weight' ]): # # "hidden weights"
50+ std = self .config .base_std
51+
52+ if 'mixer.out_proj.weight' in pn :
53+                std = std / math .sqrt (2 * self .config .n_layers ) # scale down std of layers which project onto the residual stream (not muP related)
54+
55+ if 'mixer.dt_proj.weight' in pn :
56+ std = self .config .dt_rank ** - 0.5 * self .config .dt_scale
57+ torch .nn .init .normal_ (p , mean = 0.0 , std = std / math .sqrt (self .config .mup_width_mult ))
58+ elif 'mixer.conv1d.weight' in pn :
59+ torch .nn .init .zeros_ (p )
60+ elif pn == "embedding.weight" :
61+ torch .nn .init .normal_ (p , mean = 0.0 , std = self .config .base_std )
62+ elif pn == "lm_head.weight" :
63+ torch .nn .init .zeros_ (p )
64+ elif any (pn .endswith (w ) for w in ['mixer.A_log' , 'mixer.D' ]):
65+ # keep Mamba default init for these params
66+ pass
67+ else :
68+ # here, we only have biases
69+ assert p .dim () == 1 , f"a 2d param ({ pn } ) has not been filtered out for init. please check."
70+
71+ if ("in_proj.bias" in pn ) or ("out_proj.bias" in pn ):
72+ torch .nn .init .zeros_ (p )
73+
74+ elif self .config .mup and isinstance (self .config , Mamba2Config ):
75+ for pn , p in self .named_parameters ():
76+
77+ if any (pn .endswith (w ) for w in ['mixer.in_proj.weight' , 'mixer.out_proj.weight' ]): # # "hidden weights"
78+ std = self .config .base_std
79+
80+ if 'mixer.out_proj.weight' in pn :
81+                std = std / math .sqrt (2 * self .config .n_layers ) # scale down std of layers which project onto the residual stream (not muP related)
82+
83+ torch .nn .init .normal_ (p , mean = 0.0 , std = std / math .sqrt (self .config .mup_width_mult ))
84+ elif 'mixer.conv1d.weight' in pn :
85+ torch .nn .init .zeros_ (p )
86+ elif pn == "embedding.weight" :
87+ torch .nn .init .normal_ (p , mean = 0.0 , std = self .config .base_std )
88+ elif pn == "lm_head.weight" :
89+ torch .nn .init .zeros_ (p )
90+ elif any (pn .endswith (w ) for w in ['mixer.A_log' , 'mixer.D' , 'mixer.dt_bias' ]):
91+ # keep Mamba default init for these params
92+ pass
93+ else :
94+ # here, we only have biases
95+ assert p .dim () == 1 , f"a 2d param ({ pn } ) has not been filtered out for init. please check."
4596
46- self .apply (self ._init_weights )
47- for pn , p in self .named_parameters ():
48- if pn .endswith ('fc_3.weight' ) or pn .endswith ('c_proj.weight' ):
49- torch .nn .init .normal_ (p , mean = 0.0 , std = self .config .base_std / math .sqrt (2 * self .config .n_layers ))
97+ if ("in_proj.bias" in pn ) or ("out_proj.bias" in pn ):
98+ torch .nn .init .zeros_ (p )
99+
100+ else :
101+ self .apply (self ._init_weights )
102+ for pn , p in self .named_parameters ():
103+ if pn .endswith ('mixer.out_proj.weight' ):
104+ torch .nn .init .normal_ (p , mean = 0.0 , std = self .config .base_std / math .sqrt (2 * self .config .n_layers ))
50105
51106 def forward (self , tokens ):
52107 # tokens : (B, L)
@@ -142,13 +197,46 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
142197
143198         # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
144199 # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
145- decay_params = [p for n , p in param_dict .items () if p .dim () >= 2 ]
146- nodecay_params = [p for n , p in param_dict .items () if p .dim () < 2 ]
200+ if self .config .mup and isinstance (self .config , MambaConfig ):
201+ mup_params_keys = set ([pn for pn in param_dict .keys () if any (pn .endswith (w ) for w in ['mixer.in_proj.weight' , 'mixer.x_proj.weight' , 'mixer.dt_proj.weight' , 'mixer.out_proj.weight' ])])
202+
203+ dim2_params_keys = set ([pn for pn in param_dict .keys () if param_dict [pn ].dim () >= 2 ])
204+ dim2_params_keys = dim2_params_keys .difference (mup_params_keys )
205+
206+ mup_parameters = [p for n , p in param_dict .items () if n in mup_params_keys ]
207+ decay_params = [p for n , p in param_dict .items () if n in dim2_params_keys ]
208+ nodecay_params = [p for n , p in param_dict .items () if p .dim () < 2 ] # biases and D
209+
210+ optim_groups = [
211+ {'params' : mup_parameters , 'weight_decay' : weight_decay * self .config .mup_width_mult , 'lr' : learning_rate / self .config .mup_width_mult },
212+ {'params' : decay_params , 'weight_decay' : weight_decay , 'lr' : learning_rate },
213+ {'params' : nodecay_params , 'weight_decay' : 0.0 , 'lr' : learning_rate }
214+ ]
215+
216+ elif self .config .mup and isinstance (self .config , Mamba2Config ):
217+ mup_params_keys = set ([pn for pn in param_dict .keys () if any (pn .endswith (w ) for w in ['mixer.in_proj.weight' , 'mixer.out_proj.weight' ])])
218+
219+ dim2_params_keys = set ([pn for pn in param_dict .keys () if param_dict [pn ].dim () >= 2 ])
220+ dim2_params_keys = dim2_params_keys .difference (mup_params_keys )
221+
222+ mup_parameters = [p for n , p in param_dict .items () if n in mup_params_keys ]
223+ decay_params = [p for n , p in param_dict .items () if n in dim2_params_keys ]
224+ nodecay_params = [p for n , p in param_dict .items () if p .dim () < 2 ] # biases and D and A
225+
226+ optim_groups = [
227+ {'params' : mup_parameters , 'weight_decay' : weight_decay * self .config .mup_width_mult , 'lr' : learning_rate / self .config .mup_width_mult },
228+ {'params' : decay_params , 'weight_decay' : weight_decay , 'lr' : learning_rate },
229+ {'params' : nodecay_params , 'weight_decay' : 0.0 , 'lr' : learning_rate }
230+ ]
231+
232+ else :
233+ decay_params = [p for n , p in param_dict .items () if p .dim () >= 2 ]
234+ nodecay_params = [p for n , p in param_dict .items () if p .dim () < 2 ]
147235
148- optim_groups = [
149- {'params' : decay_params , 'weight_decay' : weight_decay , 'lr' : learning_rate },
150- {'params' : nodecay_params , 'weight_decay' : 0.0 , 'lr' : learning_rate }
151- ]
236+ optim_groups = [
237+ {'params' : decay_params , 'weight_decay' : weight_decay , 'lr' : learning_rate },
238+ {'params' : nodecay_params , 'weight_decay' : 0.0 , 'lr' : learning_rate }
239+ ]
152240
153241 # Create AdamW optimizer and use the fused version if it is available
154242 fused_available = 'fused' in inspect .signature (torch .optim .AdamW ).parameters
0 commit comments