110 changes: 99 additions & 11 deletions mambapy/lm.py
@@ -41,12 +41,67 @@ def __init__(self, model_config: Union[MambaConfig, Mamba2Config], vocab_size: i
self.norm_f = RMSNorm(self.config.d_model, self.config.rms_norm_eps, self.config.mup)

self.lm_head = nn.Linear(self.config.d_model, vocab_size, bias=False)
# self.embedding.weight = self.lm_head.weight # weight-tying disabled

# muP custom initialization
if self.config.mup and isinstance(self.config, MambaConfig):
for pn, p in self.named_parameters():
if any(pn.endswith(w) for w in ['mixer.in_proj.weight', 'mixer.x_proj.weight', 'mixer.dt_proj.weight', 'mixer.out_proj.weight']): # "hidden weights"
std = self.config.base_std

if 'mixer.out_proj.weight' in pn:
std = std / math.sqrt(2 * self.config.n_layers) # scale down the std of layers that project onto the residual stream (not muP-related)

if 'mixer.dt_proj.weight' in pn:
std = self.config.dt_rank**-0.5 * self.config.dt_scale
torch.nn.init.normal_(p, mean=0.0, std=std / math.sqrt(self.config.mup_width_mult))
elif 'mixer.conv1d.weight' in pn:
torch.nn.init.zeros_(p)
elif pn == "embedding.weight":
torch.nn.init.normal_(p, mean=0.0, std=self.config.base_std)
elif pn == "lm_head.weight":
torch.nn.init.zeros_(p)
elif any(pn.endswith(w) for w in ['mixer.A_log', 'mixer.D']):
# keep Mamba default init for these params
pass
else:
# here, we only have biases
assert p.dim() == 1, f"a 2D param ({pn}) was not filtered out for init, please check"

if ("in_proj.bias" in pn) or ("out_proj.bias" in pn):
torch.nn.init.zeros_(p)

elif self.config.mup and isinstance(self.config, Mamba2Config):
for pn, p in self.named_parameters():

if any(pn.endswith(w) for w in ['mixer.in_proj.weight', 'mixer.out_proj.weight']): # "hidden weights"
std = self.config.base_std

if 'mixer.out_proj.weight' in pn:
std = std / math.sqrt(2 * self.config.n_layers) # scale down the std of layers that project onto the residual stream (not muP-related)

torch.nn.init.normal_(p, mean=0.0, std=std / math.sqrt(self.config.mup_width_mult))
elif 'mixer.conv1d.weight' in pn:
torch.nn.init.zeros_(p)
elif pn == "embedding.weight":
torch.nn.init.normal_(p, mean=0.0, std=self.config.base_std)
elif pn == "lm_head.weight":
torch.nn.init.zeros_(p)
elif any(pn.endswith(w) for w in ['mixer.A_log', 'mixer.D', 'mixer.dt_bias']):
# keep Mamba default init for these params
pass
else:
# here, we only have biases
assert p.dim() == 1, f"a 2D param ({pn}) was not filtered out for init, please check"

self.apply(self._init_weights)
for pn, p in self.named_parameters():
if pn.endswith('fc_3.weight') or pn.endswith('c_proj.weight'):
torch.nn.init.normal_(p, mean=0.0, std=self.config.base_std/math.sqrt(2 * self.config.n_layers))
if ("in_proj.bias" in pn) or ("out_proj.bias" in pn):
torch.nn.init.zeros_(p)

else:
self.apply(self._init_weights)
for pn, p in self.named_parameters():
if pn.endswith('mixer.out_proj.weight'):
torch.nn.init.normal_(p, mean=0.0, std=self.config.base_std/math.sqrt(2 * self.config.n_layers))

def forward(self, tokens):
# tokens : (B, L)
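
A minimal usage sketch of the muP init path added above. The class name LM, the import paths, and the config values are assumptions inferred from this diff, not part of the PR:

from mambapy.mamba import MambaConfig
from mambapy.lm import LM  # assumed class name for the model defined in this file

# d_model=512 with mup_base_width=128 gives mup_width_mult = 4.0
config = MambaConfig(d_model=512, n_layers=12, mup=True, mup_base_width=128)
model = LM(config, vocab_size=32000)

# with mup=True: hidden weights are initialized with std = base_std / sqrt(4.0),
# lm_head.weight starts at zero (muP readout), and embedding keeps std = base_std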
@@ -142,13 +197,46 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):

# create optim groups. Any parameter that is 2D will be weight-decayed, otherwise not.
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
if self.config.mup and isinstance(self.config, MambaConfig):
mup_params_keys = set([pn for pn in param_dict.keys() if any(pn.endswith(w) for w in ['mixer.in_proj.weight', 'mixer.x_proj.weight', 'mixer.dt_proj.weight', 'mixer.out_proj.weight'])])

dim2_params_keys = set([pn for pn in param_dict.keys() if param_dict[pn].dim() >= 2])
dim2_params_keys = dim2_params_keys.difference(mup_params_keys)

mup_parameters = [p for n, p in param_dict.items() if n in mup_params_keys]
decay_params = [p for n, p in param_dict.items() if n in dim2_params_keys]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] # biases and D

optim_groups = [
{'params': mup_parameters, 'weight_decay': weight_decay * self.config.mup_width_mult, 'lr': learning_rate / self.config.mup_width_mult},
{'params': decay_params, 'weight_decay': weight_decay, 'lr': learning_rate},
{'params': nodecay_params, 'weight_decay': 0.0, 'lr': learning_rate}
]

elif self.config.mup and isinstance(self.config, Mamba2Config):
mup_params_keys = set([pn for pn in param_dict.keys() if any(pn.endswith(w) for w in ['mixer.in_proj.weight', 'mixer.out_proj.weight'])])

dim2_params_keys = set([pn for pn in param_dict.keys() if param_dict[pn].dim() >= 2])
dim2_params_keys = dim2_params_keys.difference(mup_params_keys)

mup_parameters = [p for n, p in param_dict.items() if n in mup_params_keys]
decay_params = [p for n, p in param_dict.items() if n in dim2_params_keys]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] # biases and D and A

optim_groups = [
{'params': mup_parameters, 'weight_decay': weight_decay * self.config.mup_width_mult, 'lr': learning_rate / self.config.mup_width_mult},
{'params': decay_params, 'weight_decay': weight_decay, 'lr': learning_rate},
{'params': nodecay_params, 'weight_decay': 0.0, 'lr': learning_rate}
]

else:
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]

optim_groups = [
{'params': decay_params, 'weight_decay': weight_decay, 'lr': learning_rate},
{'params': nodecay_params, 'weight_decay': 0.0, 'lr': learning_rate}
]

# Create AdamW optimizer and use the fused version if it is available
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
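For reference, a sketch of what the muP branch of configure_optimizers above produces; the hyperparameter numbers are illustrative assumptions, not from the PR:

# with learning_rate=3e-4, weight_decay=0.1 and mup_width_mult=4.0,
# the muP branch builds three parameter groups:
#   mup_parameters -> lr = 3e-4 / 4, weight_decay = 0.1 * 4
#   decay_params   -> lr = 3e-4,     weight_decay = 0.1
#   nodecay_params -> lr = 3e-4,     weight_decay = 0.0
# scaling wd up by width_mult keeps lr * wd (AdamW's effective decay per
# step, since decoupled decay applies p -= lr * wd * p) constant across widths
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=3e-4, betas=(0.9, 0.95), device_type="cuda")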
29 changes: 22 additions & 7 deletions mambapy/mamba.py
@@ -50,6 +50,9 @@ class MambaConfig:
conv_bias: bool = True
inner_layernorms: bool = False # apply layernorms to internal activations

mup: bool = False
mup_base_width: float = 128 # width = d_model of the base (small) model

pscan: bool = True # use parallel scan mode or sequential mode when training
use_cuda: bool = False # use official CUDA implementation when training (not compatible with (b)float16)

@@ -59,6 +62,10 @@ def __post_init__(self):
if self.dt_rank == 'auto':
self.dt_rank = math.ceil(self.d_model / 16)

# muP
if self.mup:
self.mup_width_mult = self.d_model / self.mup_base_width
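# e.g. d_model=512 with mup_base_width=128 -> mup_width_mult = 4.0:
# hidden-weight init std is then divided by sqrt(4.0) and their lr by 4.0 (see lm.py)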

class Mamba(nn.Module):
def __init__(self, config: MambaConfig):
super().__init__()
@@ -94,7 +101,7 @@ def __init__(self, config: MambaConfig):
super().__init__()

self.mixer = MambaBlock(config)
self.norm = RMSNorm(config.d_model, config.rms_norm_eps, config.mup)

def forward(self, x):
# x : (B, L, D)
@@ -170,9 +177,9 @@ def __init__(self, config: MambaConfig):

# used in jamba
if self.config.inner_layernorms:
self.dt_layernorm = RMSNorm(self.config.dt_rank, config.rms_norm_eps, config.mup)
self.B_layernorm = RMSNorm(self.config.d_state, config.rms_norm_eps, config.mup)
self.C_layernorm = RMSNorm(self.config.d_state, config.rms_norm_eps, config.mup)
else:
self.dt_layernorm = None
self.B_layernorm = None
@@ -407,13 +414,21 @@ def ssm_step(self, x, h):
return y, h

class RMSNorm(nn.Module):
def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
super().__init__()

self.use_mup = use_mup
self.eps = eps

# https://arxiv.org/abs/2404.05728: RMSNorm gains prevent muTransfer (section 4.2.3)
if not use_mup:
self.weight = nn.Parameter(torch.ones(d_model))

def forward(self, x):
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

if not self.use_mup:
return output * self.weight
else:
return output
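
A quick numerical check of the gainless RMSNorm above; a sketch, not part of the PR:

import torch

x = torch.randn(2, 16, 512)
eps = 1e-5
out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
# each vector along the last dim now has RMS ~ 1, with no learned gain;
# per https://arxiv.org/abs/2404.05728 (section 4.2.3) the gain is dropped
# under muP because a trainable elementwise scale breaks muTransfer
print(out.pow(2).mean(-1).sqrt()[0, 0].item())  # ~1.0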

26 changes: 21 additions & 5 deletions mambapy/mamba2.py
@@ -3,7 +3,7 @@
"""

adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2_simple.py
It just implements a config similar to what's done in mamba.py, and adds muP support.

"""

@@ -56,6 +56,9 @@ class Mamba2Config:
bias: bool = False
conv_bias: bool = True

mup: bool = False
mup_base_width: float = 128 # width = d_model of the base (small) model

chunk_size: int = 256
use_mem_eff_path: bool = True
dtype=None
@@ -68,6 +71,10 @@ def __post_init__(self):

assert (self.d_inner / self.d_head) % 8 == 0, "requirement of causal_conv1d"

# muP
if self.mup:
self.mup_width_mult = self.d_model / self.mup_base_width

class Mamba2(nn.Module):
def __init__(self, config: Mamba2Config):
super().__init__()
@@ -103,7 +110,7 @@ def __init__(self, config: Mamba2Config):
super().__init__()

self.mixer = Mamba2Block(config)
self.norm = RMSNorm(config.d_model, config.rms_norm_eps, config.mup)

def forward(self, x):
# x : (B, L, D)
@@ -152,6 +159,7 @@ def __init__(self, config: Mamba2Config):
nn.init.uniform_(self.conv1d.weight, -self.config.conv_init, self.config.conv_init)
# self.conv1d.weight._no_weight_decay = True

# TODO: muP init + lr
if self.config.learnable_init_states:
self.init_states = nn.Parameter(torch.zeros(self.config.n_heads, self.config.d_head, self.config.d_state, **factory_kwargs))
self.init_states._no_weight_decay = True
@@ -267,12 +275,20 @@ def forward(self, u, seq_idx=None):

# taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
class RMSNorm(nn.Module):
def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
super().__init__()

self.use_mup = use_mup
self.eps = eps

# https://arxiv.org/abs/2404.05728: RMSNorm gains prevent muTransfer (section 4.2.3)
if not use_mup:
self.weight = nn.Parameter(torch.ones(d_model))

def forward(self, x):
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

if not self.use_mup:
return output * self.weight
else:
return output