Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
165 commits
Select commit Hold shift + click to select a range
5137757
stuff
jlamypoirier Mar 26, 2025
f0cb32a
Merge remote-tracking branch 'origin/main' into config_updates
jlamypoirier Mar 26, 2025
f26010e
Update pretrained config
jlamypoirier Mar 27, 2025
b930a39
stuff
jlamypoirier Mar 27, 2025
918a7a8
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Mar 27, 2025
8117c47
fixes
jlamypoirier Mar 27, 2025
1c995d3
fix
jlamypoirier Mar 27, 2025
3f90475
Merge branch 'main' into config_updates
jlamypoirier Mar 27, 2025
e389058
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Mar 27, 2025
506fe92
fixes
jlamypoirier Mar 27, 2025
971d3ef
fixes
jlamypoirier Mar 27, 2025
6bf20cb
Tests wip
jlamypoirier Mar 28, 2025
c13fb19
misc
jlamypoirier Mar 29, 2025
a20fcec
tests
jlamypoirier Apr 1, 2025
9af26a7
Merge branch 'main' into config_updates
jlamypoirier Apr 1, 2025
9af372d
Tests, fixes, remove tuple format
jlamypoirier Apr 1, 2025
dded00a
fix
jlamypoirier Apr 2, 2025
42d5ca4
Merge remote-tracking branch 'origin/main' into config_updates
jlamypoirier Apr 2, 2025
986f9f3
fix
jlamypoirier Apr 2, 2025
5abc087
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Apr 2, 2025
8e3e795
fixes
jlamypoirier Apr 2, 2025
da6eb7b
fixes
jlamypoirier Apr 3, 2025
67e08aa
Merge branch 'main' into config_updates
jlamypoirier Apr 3, 2025
a09e6f3
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Apr 3, 2025
baad705
fix
jlamypoirier Apr 3, 2025
b702837
Test, fixes
jlamypoirier Apr 5, 2025
a8684f8
Knowledge distillation, fix cross-entropy
jlamypoirier Apr 11, 2025
b781729
Fixes, distillation
jlamypoirier Apr 13, 2025
db6504b
fixes
jlamypoirier Apr 14, 2025
7c2933a
Merge remote-tracking branch 'origin/main' into config_updates
jlamypoirier Apr 14, 2025
a017c11
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Apr 14, 2025
368a6bf
Merge remote-tracking branch 'origin/main' into update_pretrained_config
jlamypoirier Apr 14, 2025
e0c82a0
Merge remote-tracking branch 'origin/main' into distillation
jlamypoirier Apr 14, 2025
16a3dd7
Merge branch 'update_pretrained_config' into distillation
jlamypoirier Apr 14, 2025
cff9892
fixes
jlamypoirier Apr 14, 2025
793ecde
Merge branch 'update_pretrained_config' into distillation
jlamypoirier Apr 14, 2025
b67006a
fixes
jlamypoirier Apr 15, 2025
2014108
Add constraints
jlamypoirier Apr 16, 2025
4fb78e4
Merge remote-tracking branch 'origin/main' into distillation
jlamypoirier Apr 16, 2025
fa3d556
Add constraints
jlamypoirier Apr 16, 2025
6c2c887
Separate reference model preprocessing
jlamypoirier Apr 16, 2025
67f9db6
fix
jlamypoirier Apr 16, 2025
48141e5
Merge remote-tracking branch 'origin/main' into update_pretrained_config
jlamypoirier Apr 17, 2025
e6e5a32
Merge branch 'update_pretrained_config' into distillation
jlamypoirier Apr 17, 2025
537deca
fix
jlamypoirier Apr 17, 2025
a590e8b
Merge branch 'distillation' into reference_model_preprocessing
jlamypoirier Apr 17, 2025
3d5dc94
Merge commit '6ad0a96c9328234b907d01a82c4c52bd48752b2f' into update_p…
jlamypoirier Apr 18, 2025
2bb0c08
Merge branch 'update_pretrained_config' into distillation
jlamypoirier Apr 18, 2025
067ba97
Merge remote-tracking branch 'origin/main' into distillation
jlamypoirier Apr 21, 2025
d2b3154
misc
jlamypoirier Apr 21, 2025
2e63d29
Merge branch 'distillation' into reference_model_preprocessing
jlamypoirier Apr 22, 2025
7133e4d
Merge remote-tracking branch 'origin/main' into reference_model_prepr…
jlamypoirier Apr 22, 2025
a0ba051
fixes
jlamypoirier Apr 25, 2025
9ddfb69
add per-layer lr-scale
RaymondLi0 Apr 25, 2025
5e282cc
modeling mtp llamba
oleksost Apr 28, 2025
87b3197
modeling apriel ssm
oleksost Apr 29, 2025
d3e1df2
Apriel to SSM
oleksost Apr 29, 2025
082cf22
Apriel SSM conversion
oleksost Apr 29, 2025
66fb0a2
Merge remote-tracking branch 'origin/main' into reference_model_prepr…
jlamypoirier Apr 29, 2025
0d4d5c5
fix
jlamypoirier Apr 29, 2025
b5ffd26
Merge remote-tracking branch 'origin/reference_model_preprocessing' i…
oleksost Apr 29, 2025
c43e535
wip
oleksost Apr 29, 2025
a1f44d4
conversion apriel ssm
oleksost Apr 29, 2025
fbec02d
config apriel
oleksost Apr 29, 2025
75d6460
temp checkpoint conversion
oleksost Apr 29, 2025
73a4252
block pattern for hybrid conversion
oleksost Apr 30, 2025
5afc7dc
SSMBlockType
oleksost Apr 30, 2025
8e9facf
wip
oleksost Apr 30, 2025
77ad39f
add token-prediction loss coefficients
RaymondLi0 Apr 30, 2025
da9bf1a
eval apriel ssm
oleksost May 1, 2025
ac4a598
fix
jlamypoirier May 1, 2025
0c0e7d9
adding check for missing `rope_type` (#246)
nitsanluke May 1, 2025
97ba9d4
Loss masking for distillation
jlamypoirier May 1, 2025
231d5d8
test, misc
jlamypoirier May 1, 2025
d7922af
Merge branch 'reference_model_preprocessing' into distillation_loss_mask
jlamypoirier May 1, 2025
30a75b0
eval apriel ssm
oleksost May 1, 2025
a50bc2e
cleanup
oleksost May 1, 2025
f8af7be
Merge branch 'oleksiy/apriel-ssm' of https://github.com/ServiceNow/Fa…
oleksost May 1, 2025
6532c5f
hybrid config
oleksost May 2, 2025
2a5646b
Merge remote-tracking branch 'origin/distillation_loss_mask' into ole…
oleksost May 2, 2025
9a678df
sft distill
oleksost May 2, 2025
a7abe53
conversion
oleksost May 2, 2025
a68c0b7
conversion
oleksost May 2, 2025
9cfef44
lr stage definition as string
oleksost May 2, 2025
005e623
fixes
jlamypoirier May 2, 2025
cad951a
fix
jlamypoirier May 2, 2025
40970ec
Merge branch 'reference_model_preprocessing' into distillation_loss_mask
jlamypoirier May 2, 2025
bce916d
loss maks
oleksost May 2, 2025
9d95064
fix
jlamypoirier May 2, 2025
2c96abb
Merge remote-tracking branch 'origin/main' into reference_model_prepr…
jlamypoirier May 2, 2025
935c470
fix
jlamypoirier May 2, 2025
9aff3b7
fix shuffled tokens
oleksost May 2, 2025
d82ddbf
Merge remote-tracking branch 'origin/main' into reference_model_prepr…
jlamypoirier May 2, 2025
6949c49
Merge branch 'reference_model_preprocessing' into distillation_loss_mask
jlamypoirier May 2, 2025
9c105e7
Merge remote-tracking branch 'origin/main' into distillation_loss_mask
jlamypoirier May 2, 2025
ae4d111
fixes
jlamypoirier May 2, 2025
deb7ce6
fixes
jlamypoirier May 2, 2025
eaba34f
innit like in mamba in llama
oleksost May 2, 2025
f8ca122
embeddings_lr_scale
oleksost May 5, 2025
2db740b
fix
jlamypoirier May 5, 2025
41d4da3
disable freezing
RaymondLi0 May 5, 2025
4160b1f
hybrid model loading and exporting
oleksost May 6, 2025
30ad8b8
wip
oleksost May 7, 2025
ea55ae2
Merge branch 'main' into oleksiy/apriel-ssm
oleksost May 7, 2025
cd4edd5
Merge remote-tracking branch 'origin/distillation_loss_mask' into ole…
oleksost May 7, 2025
9c4f38f
layer-lr scale for mlp as well
RaymondLi0 May 7, 2025
1784dca
wip
oleksost May 7, 2025
1e3cc28
nvm
oleksost May 7, 2025
2dc945b
hybrid modeling
oleksost May 9, 2025
4277e67
modeling
oleksost May 9, 2025
6153c33
Merge branch 'main' into oleksiy/apriel-ssm
oleksost May 9, 2025
c71cb16
nvm
oleksost May 9, 2025
be04c19
output lr scale
oleksost May 9, 2025
1311f5b
output_lr_scale
oleksost May 9, 2025
baf4011
nvm
oleksost May 9, 2025
6cf26c5
eval
oleksost May 10, 2025
901d1b6
rename
oleksost May 12, 2025
b5696fb
Merge remote-tracking branch 'origin/raymond/per_layer_lr_scale' into…
oleksost May 12, 2025
616c540
per_layer_lr_scale
oleksost May 12, 2025
9af5ee5
merged also prediction_loss_coefficient from #243
oleksost May 12, 2025
1a7939b
added logging in mamba
oleksost May 12, 2025
532d0d5
no norm layer freezing
oleksost May 12, 2025
8349130
test
oleksost May 12, 2025
023102c
test
oleksost May 12, 2025
865da95
debug
oleksost May 12, 2025
87c93d3
comment
oleksost May 12, 2025
da4977d
Merge remote-tracking branch 'origin/main' into oleksiy/apriel-ssm
oleksost May 12, 2025
a18b80f
debug
oleksost May 12, 2025
40d5437
wip
oleksost May 14, 2025
72ace3b
fix
RaymondLi0 May 14, 2025
121e906
test + comment
oleksost May 14, 2025
8a8fa77
fix
RaymondLi0 May 16, 2025
8e25990
add test with frozen weights
RaymondLi0 May 16, 2025
456a0c5
add description for tests
RaymondLi0 May 16, 2025
87efd45
15b model apriel hybrid
oleksost May 20, 2025
95c7b53
Merge remote-tracking branch 'origin/main' into oleksiy/apriel-ssm
oleksost May 20, 2025
326387d
Merge remote-tracking branch 'origin/raymond/fix-frozen-weight' into …
oleksost May 20, 2025
aafbfb5
nvm
oleksost May 20, 2025
c7fe8d7
nvm
oleksost May 20, 2025
848ef04
Merge remote-tracking branch 'origin/main' into oleksiy/apriel-ssm
oleksost May 20, 2025
c285e8d
nvm
oleksost May 20, 2025
6765d28
hybrid thinker
oleksost May 21, 2025
b0fe37b
nvm
oleksost May 22, 2025
8f84a49
modeling
oleksost May 22, 2025
361aad0
wip
oleksost May 22, 2025
65e466f
nvm
oleksost May 26, 2025
45aa6e4
notebook
oleksost May 29, 2025
d4a04f2
hiddens tate mistral
oleksost Jun 3, 2025
978c0af
update transformers
oleksost Jun 3, 2025
91897db
inference optim
oleksost Jun 5, 2025
9383925
modeling
oleksost Jun 6, 2025
4b643b7
modeling
oleksost Jun 6, 2025
0021ab5
evalchemy
oleksost Jun 6, 2025
cb9a845
tokenizer
oleksost Jun 9, 2025
2424b85
tokenizer
oleksost Jun 9, 2025
e308a2c
skip causal_conv1d
oleksost Jun 9, 2025
3193db9
causal_conv1d optional
oleksost Jun 10, 2025
0f45d4a
5b notebook
oleksost Jun 11, 2025
fc6ed64
nvm
oleksost Jun 11, 2025
5fb3a39
nvm
oleksost Jun 11, 2025
386f3c1
clean
oleksost Jun 11, 2025
73284d8
clean
oleksost Jun 11, 2025
69341d5
Merge remote-tracking branch 'origin/main' into oleksiy/apriel-ssm
oleksost Jun 11, 2025
f56964e
test imports
oleksost Jun 11, 2025
9c50f9c
nvm
oleksost Jun 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions fast_llm/engine/optimizer/learning_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,19 @@ def create_schedule_from_config(config: LearningRateScheduleConfig) -> LearningR
begin_step = 0
for stage_arg_str in config.schedule.split(";"):
try:
for stage_type, num_steps, lr, *stage_args in stage_arg_str.split(","):
assert begin_step is not None
num_steps = int(num_steps)
end_step = None if num_steps < 0 else begin_step + num_steps
kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)}
if len(stage_args) > 0:
kwargs["end_lr"] = float(stage_args[0])
if len(stage_args) > 1:
kwargs["power"] = float(stage_args[1])
if len(stage_args) > 2:
raise ValueError(stage_args[2:])
stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs))
begin_step = end_step
stage_type, num_steps, lr, *stage_args = stage_arg_str.split(",")
assert begin_step is not None
num_steps = int(num_steps)
end_step = None if num_steps < 0 else begin_step + num_steps
kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)}
if len(stage_args) > 0:
kwargs["end_lr"] = float(stage_args[0])
if len(stage_args) > 1:
kwargs["power"] = float(stage_args[1])
if len(stage_args) > 2:
raise ValueError(stage_args[2:])
stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs))
begin_step = end_step
except Exception:
raise ValueError(f'Cannot parse optimizer stage definition "{stage_arg_str}"')
return LearningRateSchedule(stages)
15 changes: 14 additions & 1 deletion fast_llm/layers/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@
from fast_llm.layers.common.normalization import LayerNorm, RMSNorm


@config_class()
class LLMBlockConfig(BaseModelConfig):
    """
    Base configuration carrying the per-layer learning-rate scale option.

    `SSMConfig` derives from this class, and the discrete Mamba2 layer reads
    `per_layer_lr_scale[layer_idx]` to obtain the scale of an individual layer.
    """

    # Marks this config class as concrete (not abstract).
    _abstract = False

    # One scale per layer, looked up by layer index; `None` means no per-layer scaling.
    # NOTE(review): no `valid=` check is attached here, unlike other `*_lr_scale` fields,
    # so negative entries are not rejected at this level -- confirm this is intended.
    per_layer_lr_scale: list[float] | None = Field(
        default=None,
        desc="Custom learning rate scale for each layer.",
        doc="May be used to freeze some layers by setting their scale to zero.",
        hint=FieldHint.feature,
    )


class NormalizationImplementation(str, enum.Enum):
"""
An enum for the available implementations of layer norm.
Expand Down Expand Up @@ -68,7 +80,7 @@ class NormalizationConfig(BaseModelConfig):
valid=check_field(Assert.geq, 0),
)

def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm":
def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm":
from fast_llm.layers.common.normalization import LayerNorm, RMSNorm
from fast_llm.tensor import init_uniform_

Expand All @@ -77,6 +89,7 @@ def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm":
"eps": self.epsilon,
"implementation": self.implementation,
"zero_centered": self.zero_centered,
"lr_scale": lr_scale,
}
if self.initialization_range:
mean = 0 if self.zero_centered else 1
Expand Down
5 changes: 5 additions & 0 deletions fast_llm/layers/common/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def __init__(
weight_init_method=None,
bias_init_method=init_zeros_,
zero_centered: bool = False,
lr_scale: float | None = None,
):
super().__init__()
assert hidden_dim.parallel_dim is None
Expand Down Expand Up @@ -193,12 +194,14 @@ def __init__(
init_method=weight_init_method,
weight_decay=False,
auto_grad_accumulation=implementation == NormalizationImplementation.torch,
lr_scale=lr_scale,
)
self.bias = ParameterMeta.from_dims(
(hidden_dim,),
init_method=bias_init_method,
weight_decay=False,
auto_grad_accumulation=implementation == NormalizationImplementation.torch,
lr_scale=lr_scale,
)
self.normalized_shape = self.weight.shape

Expand Down Expand Up @@ -236,6 +239,7 @@ def __init__(
implementation: NormalizationImplementation = NormalizationImplementation.auto,
weight_init_method=None,
zero_centered: bool = False,
lr_scale: float | None = None,
):
super().__init__()
assert hidden_dim.parallel_dim is None
Expand Down Expand Up @@ -269,6 +273,7 @@ def __init__(
init_method=weight_init_method,
weight_decay=False,
auto_grad_accumulation=True,
lr_scale=lr_scale,
)
self.normalized_shape = self.weight.shape

Expand Down
23 changes: 23 additions & 0 deletions fast_llm/layers/language_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,25 @@ class LanguageModelBaseConfig(BaseModelConfig):
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
embeddings_lr_scale: float | None = Field(
default=None,
desc="Learning rate scale for the word embeddings.",
doc="May be used to freeze some layers by setting their scale to zero.",
hint=FieldHint.feature,
valid=skip_valid_if_none(check_field(Assert.geq, 0)),
)
output_lr_scale: float | None = Field(
default=None,
desc="Custom learning rate scale for the output weights.",
doc="May be used to freeze the output weights by setting their scale to zero.",
hint=FieldHint.feature,
)
prediction_loss_coefficient: list[float] | None = Field(
default=None,
desc="Loss coefficient for each prediction head.",
doc="If not provided, all heads are equally weighted.",
hint=FieldHint.feature,
)

def _validate(self) -> None:
self.transformer.validate()
Expand All @@ -173,6 +192,10 @@ def _validate(self) -> None:
if self.distillation_model is not None:
if self.prediction_heads > 1:
raise NotImplementedError("Multi-token prediction not supported with distillation.")
if isinstance(self.prediction_loss_coefficient, list):
Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads)
for coeff in self.prediction_loss_coefficient:
Assert.geq(coeff, 0)

def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
self.transformer.setup_tensor_space(tensor_space)
Expand Down
2 changes: 2 additions & 0 deletions fast_llm/layers/language_model/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
min_val=config.init_method_min_embed,
max_val=config.init_method_max_embed,
),
lr_scale=config.embeddings_lr_scale,
)
if self._use_absolute_position_embeddings:
self.position_embeddings_weight = ParameterMeta.from_dims(
Expand All @@ -72,6 +73,7 @@ def __init__(
max_val=config.init_method_max_embed,
),
allow_sequence_tensor_parallel=not config.parallel_embeddings,
lr_scale=config.embeddings_lr_scale,
)

# PEFT.
Expand Down
6 changes: 5 additions & 1 deletion fast_llm/layers/language_model/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def __init__(

hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)

self._loss_coefficient = (
config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0
)
self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance)
self.final_norm = config.transformer.normalization.get_layer(hidden_dim)
self._logits_scale_factor = config.logits_scale_factor
Expand Down Expand Up @@ -109,6 +112,7 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None:
min_val=config.init_method_min_embed,
max_val=config.init_method_max_embed,
),
lr_scale=config.output_lr_scale,
)

def forward(
Expand Down Expand Up @@ -139,7 +143,7 @@ def forward(
else:
if self.training:
# Backward hook to compute the gradient of the loss
shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, self._loss_coefficient)
# MTP: Return shared_hidden to be used by the next head.
return shared_hidden

Expand Down
37 changes: 30 additions & 7 deletions fast_llm/layers/ssm/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.engine.base_model.config import BaseModelConfig
import enum

from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
from fast_llm.functional.config import ActivationType
from fast_llm.layers.common.config import NormalizationConfig
from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig
from fast_llm.utils import Assert


Expand All @@ -20,8 +21,19 @@ class SSMDimNames:
v_heads = "v_heads" # Number of V heads


class SSMBlockType(str, enum.Enum):
    """
    An enum of short string codes for the available block types.

    Covers three Mamba variants plus a transformer block. Because the enum also
    derives from `str`, each member is itself a string equal to its code.
    NOTE(review): presumably these codes are what a hybrid block pattern is
    written in -- confirm against the hybrid model config.
    """

    # Mamba block.
    mamba = "m"
    # Discrete Mamba2 block.
    mamba2_discrete = "m2d"
    # Mamba2 block.
    mamba2 = "m2"
    # Transformer block.
    transformer = "t"


@config_class()
class SSMConfig(BaseModelConfig):
class SSMConfig(LLMBlockConfig):
_abstract = False

# Normalization
Expand Down Expand Up @@ -53,7 +65,8 @@ class SSMConfig(BaseModelConfig):
desc="Whether to use bias in SSM layers",
hint=FieldHint.architecture,
)
dt_rank: int = Field(

dt_rank: None | int = Field(
default=None,
desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)",
hint=FieldHint.architecture,
Expand Down Expand Up @@ -102,12 +115,22 @@ class SSMConfig(BaseModelConfig):
valid=check_field(Assert.gt, 0),
)

d_inner: None | int = Field(
default=None,
desc="Inner dimension for Mamba2 blocks.",
hint=FieldHint.core,
)
mamba_lr_scale: float | None = Field(
default=None,
desc="Learning rate scale for Mamba blocks.",
hint=FieldHint.feature,
valid=skip_valid_if_none(check_field(Assert.geq, 0)),
)

def _validate(self) -> None:
with self._set_implicit_default():
if self.activation_type is None:
self.activation_type = ActivationType.silu
if self.dt_rank is None:
self.dt_rank = -1 # set to -1, it will be overwrittem in ssm validation

super()._validate()
Assert.geq(self.dt_max, self.dt_min)
59 changes: 49 additions & 10 deletions fast_llm/layers/ssm/discrete_mamba2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import math

import causal_conv1d
import einops
import mamba_ssm.ops.triton.ssd_combined
import torch
Expand All @@ -9,6 +9,16 @@
from fast_llm.layers.common.linear import Linear
from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames
from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_
from fast_llm.utils import get_lr_scale

logger = logging.getLogger(__name__)

try:
import causal_conv1d
except ImportError:
# this is needed since we cannot use causal_conv1d on B200 GPUs for now
logger.warning("Note, causal_conv1d not found, will use torch.nn.functional.conv1d instead")
causal_conv1d = None

"""
This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py
Expand Down Expand Up @@ -44,6 +54,9 @@ def __init__(
bias = config.add_bias_linear
self.layer_idx = layer_idx
self._return_input = return_input
layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None
mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale)
logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}")

td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim)
td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim)
Expand All @@ -67,31 +80,41 @@ def __init__(

# TODO: double check initializations
# Projections
self.in_proj = Linear(td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size))
self.in_proj = Linear(
td_model,
td_inner_proj,
bias=bias,
weight_init_method=kaiming_init_(td_model.size),
lr_scale=mamba_layer_lr_scale,
)
self.z_bias = (
ParameterMeta.from_dims(
(td_inner,),
weight_decay=False,
init_method=init_zeros_,
lr_scale=mamba_layer_lr_scale,
)
if not bias
else 0.0
)

# Convolutional layer
self.conv1d_weight = ParameterMeta.from_dims(
(td_conv, TensorDim("1", 1), td_conv_kernel),
init_method=init_uniform_(
1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size)
), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67
lr_scale=mamba_layer_lr_scale,
)
self.conv1d_bias = ParameterMeta.from_dims(
(td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale
)
self.conv1d_bias = ParameterMeta.from_dims((td_conv,), init_method=bias_init_method(self.conv1d_weight))

# D "skip" parameter
self.D = ParameterMeta.from_dims(
(td_n_qk_heads,),
weight_decay=False,
init_method=init_ones_,
lr_scale=mamba_layer_lr_scale,
)

# out_proj
Expand All @@ -100,6 +123,7 @@ def __init__(
td_model,
bias=bias,
weight_init_method=kaiming_init_(td_inner.size),
lr_scale=mamba_layer_lr_scale,
)

@property
Expand Down Expand Up @@ -210,10 +234,25 @@ def forward(self, hidden_states, kwargs):

def convolutional_forward(self, xBC, padded_len):
    """Convolutional layer forward pass for the full sequence.

    Applies the per-channel 1-D convolution followed by the activation. The fused
    ``causal_conv1d`` kernel is used only when the optional module was imported
    successfully and the activation is one of "silu", "swish" or "identity";
    otherwise ``torch.nn.functional.conv1d`` plus ``self.act`` is used.

    Args:
        xBC: Input tensor; dims 1 and 2 are swapped before the convolution and
            swapped back afterwards (assumed (batch, length, channels) -- TODO confirm).
        padded_len: Number of leading output positions kept by the torch fallback.

    Returns:
        The convolved, activated tensor in the same layout as the input.
    """
    # `causal_conv1d` is None when the optional import failed at module load, so it
    # must be checked before any attribute access on it.
    fused_activations = ("silu", "swish", "identity")
    if causal_conv1d is None or self.activation_name not in fused_activations:
        conv_out = torch.nn.functional.conv1d(
            xBC.transpose(1, 2),
            self.conv1d_weight,
            bias=self.conv1d_bias,
            # One group per output channel: every channel has its own filter.
            groups=self.conv1d_weight.shape[0],
            # Pads both ends by kernel_size - 1; the slice below keeps only the
            # first `padded_len` outputs, which depend on current and past inputs.
            padding=self.conv_kernel_size - 1,
        )
        return self.act(conv_out[..., :padded_len].transpose(1, 2))

    # Fused path: the kernel takes a 2-D (channels, width) weight and applies the
    # activation itself (None stands for "identity").
    return causal_conv1d.causal_conv1d_fn(
        xBC.transpose(1, 2),
        einops.rearrange(self.conv1d_weight, "d 1 w -> d w"),
        self.conv1d_bias,
        activation=None if self.activation_name == "identity" else self.activation_name,
    ).transpose(1, 2)
Loading