Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
82eed2b
TP mamba
jlamypoirier Jul 21, 2025
4e310c7
TP mamba
jlamypoirier Jul 22, 2025
3cc4118
fix
jlamypoirier Jul 22, 2025
9f7f75c
fix
jlamypoirier Jul 22, 2025
4054e04
fixes
jlamypoirier Jul 23, 2025
0014cc6
fix
jlamypoirier Jul 23, 2025
47ad548
fixes
jlamypoirier Jul 23, 2025
6a074fa
fixes
jlamypoirier Jul 23, 2025
d66651f
Update external
jlamypoirier Jul 23, 2025
50083ba
SSM debugging
jlamypoirier Jul 24, 2025
5006328
Merge branch 'main' into tp_mamba
jlamypoirier Jul 24, 2025
13176bd
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
7b32699
stuff
jlamypoirier Jul 24, 2025
73f591f
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
1feccc8
stuff
jlamypoirier Jul 24, 2025
e528b50
misc
jlamypoirier Jul 24, 2025
b49c42f
misc
jlamypoirier Jul 24, 2025
bb4dcd9
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
c1b7f44
misc
jlamypoirier Jul 24, 2025
31f5d41
misc
jlamypoirier Jul 24, 2025
051bb07
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
0a9ff25
misc
jlamypoirier Jul 24, 2025
e7d9636
Parallel discrete mamba 2
jlamypoirier Jul 24, 2025
c14b764
Mamba 2, misc
jlamypoirier Jul 25, 2025
b605bd2
doc
jlamypoirier Jul 25, 2025
5eea938
fix
jlamypoirier Jul 28, 2025
0a3e2a7
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
2e6d082
fixes
jlamypoirier Jul 28, 2025
b6c8613
misc
jlamypoirier Jul 28, 2025
f0c04cf
Merge remote-tracking branch 'origin/main' into debug_mamba
jlamypoirier Jul 28, 2025
acdfab1
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
e536af9
Concatenated dim
jlamypoirier Jul 28, 2025
017f5cc
fixes
jlamypoirier Jul 28, 2025
93e4c94
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 28, 2025
c41efc2
doc
jlamypoirier Jul 28, 2025
0b8bd5d
cleanup
jlamypoirier Jul 28, 2025
6bf06d6
fix
jlamypoirier Jul 29, 2025
2ddc3a7
fix
jlamypoirier Jul 29, 2025
c0f1597
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 29, 2025
cef7c15
fix
jlamypoirier Jul 30, 2025
5a0eabc
Merge remote-tracking branch 'origin/main' into debug_mamba
jlamypoirier Aug 8, 2025
dd288df
Merge branch 'debug_mamba' into concatenated_dim
jlamypoirier Aug 8, 2025
defd6e0
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Aug 8, 2025
8abf258
fixes
jlamypoirier Aug 8, 2025
be99372
Merge branch 'main' into debug_mamba
jlamypoirier Aug 12, 2025
a505f3a
Merge branch 'debug_mamba' into concatenated_dim
jlamypoirier Aug 12, 2025
0cc859a
Merge remote-tracking branch 'origin/main' into concatenated_dim
jlamypoirier Aug 12, 2025
bd4ff0d
doc
jlamypoirier Aug 12, 2025
fd3307d
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Aug 12, 2025
0e2e124
stuff
jlamypoirier Aug 12, 2025
9a2a7a2
Pr comments
jlamypoirier Aug 21, 2025
8c382a9
Cleanup
jlamypoirier Aug 21, 2025
019e43d
Cleanup
jlamypoirier Aug 21, 2025
3e0f3e5
Cleanup
jlamypoirier Aug 21, 2025
1abdd19
fixes
jlamypoirier Aug 21, 2025
7c24292
fixes
jlamypoirier Aug 21, 2025
af2964b
fixes
jlamypoirier Aug 21, 2025
188587e
Merge branch 'main' into concatenated_dim
jlamypoirier Sep 17, 2025
e111509
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Sep 17, 2025
b29d657
Merge remote-tracking branch 'origin/main' into tp_mamba
jlamypoirier Sep 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions fast_llm/layers/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class NormalizationConfig(BaseModelConfig):
pass

@abc.abstractmethod
def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module":
def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None) -> "torch.nn.Module":
pass

@classmethod
Expand All @@ -63,7 +63,7 @@ def _from_dict(
class NoNormalizationConfig(NormalizationConfig):
_abstract = False

def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module":
def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None) -> "torch.nn.Module":
return torch.nn.Identity()


Expand Down
143 changes: 81 additions & 62 deletions fast_llm/layers/ssm/config.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,14 @@
import enum
import typing

from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.functional.config import ActivationType
from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig
from fast_llm.utils import Assert


class SSMDimNames:
model_dim = "model_dim" # Model dimension (D)
state_dim = "state_dim" # State dimension (N)
conv_dim = "conv_dim" # Dimension of the conv1d input in mamba layers
inner_dim = "inner_dim" # Inner dimension after expansion
dt_rank = "dt_rank" # Rank of Δ
inner_proj_mamba = "inner_proj_mamba" # Inner projection dimension for mamba
inner_proj_discrete_mamba2 = "inner_proj_discrete_mamba2" # Inner projection dimension for discrete mamba2
inner_proj_mamba2 = "inner_proj_mamba2" # Inner projection dimension for mamba2
x_proj_dim = "x_proj_dim" # X projection dimension
head_dim = "head_dim" # Dimension of the mamba2 head (P)
conv_kernel_size = "conv_kernel_size" # Kernel size of the conv1d in mamba layers
qk_heads = "qk_heads" # Number of QK heads
v_heads = "v_heads" # Number of V heads

# Mamba 2
x_proj_dim_2 = "x_proj_dim_2" # d_xb
c_heads = "c_heads"
if typing.TYPE_CHECKING:
from fast_llm.tensor import Initializer


class SSMBlockType(enum.StrEnum):
Expand Down Expand Up @@ -53,6 +38,16 @@ def get_mixer_class(self):
raise NotImplementedError(self)


class DTInitType(enum.StrEnum):
constant = "constant"
random = "random"

def get_init_method(self, scale: float) -> "Initializer":
from fast_llm.tensor import init_fill_, init_uniform_centered_

return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale)


@config_class()
class SSMConfig(LLMBlockConfig):
_abstract = False
Expand All @@ -62,106 +57,126 @@ class SSMConfig(LLMBlockConfig):
desc="Configuration for the normalization layers architecture.",
hint=FieldHint.architecture,
)

# Model dimensions
# TODO: Remove (redundant default)
expansion_factor: int = Field(
default=2,
desc="Expansion factor for Mamba blocks.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
# head_size [MambaLayer, Mamba2, DiscreteMamba2]
state_size: int = Field(
default=16,
desc="State size for Mamba blocks.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
# [MambaLayer, Mamba2, DiscreteMamba2]
conv_kernel_dimension: int = Field(
default=4,
desc="Conv kernel dimension for Mamba blocks.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
# Layer parameters
add_bias_linear: bool = Field(
default=False,
desc="Whether to use bias in SSM layers",
hint=FieldHint.architecture,
)

# [MambaLayer, Mamba2]
dt_rank: None | int = Field(
default=None,
desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)",
hint=FieldHint.architecture,
)
chunk_size: int = Field(
default=256,
desc="Chunk size for Mamba2 blocks.",
hint=FieldHint.architecture,
)
# head_groups [DiscreteMamba2]
n_qk_heads: int = Field(
default=32,
desc="Number of QK heads for Mamba2 blocks.",
hint=FieldHint.architecture,
)
# heads [DiscreteMamba2]# TODO: Remove? (redundant)
n_v_heads: int = Field(
default=32,
desc="Number of V heads for Mamba2 blocks.",
hint=FieldHint.architecture,
)
activation_type: ActivationType = Field(
# c_size [MambaLayer, Mamba2, DiscreteMamba2]?
d_inner: None | int = Field(
default=None,
desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.",
hint=FieldHint.architecture,
)
dt_min: float = Field(
default=0.001,
desc="Minimum step size for discretization",
desc="Inner dimension for Mamba2 blocks.",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
dt_init_floor: float = Field(
default=1e-4,
desc="Minimum value for initializing dt",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
# xb_size [Mamba2]
d_xb: int = Field(
default=None,
desc="Dimension of the xB in Mamba2 blocks.",
hint=FieldHint.architecture,
)

d_inner: None | int = Field(
# Model options
# add_bias_linear [Mamba2, DiscreteMamba2] [hard-coded to False in MambaLayer]
add_bias_linear: bool = Field(
default=False,
desc="Whether to use bias in SSM layers",
hint=FieldHint.architecture,
)
# activation_type [DiscreteMamba2] [hard-coded to silu in MambaLayer, Mamba2]
activation_type: ActivationType = Field(
default=None,
desc="Inner dimension for Mamba2 blocks.",
hint=FieldHint.core,
hint=FieldHint.architecture,
)
# repeat_xb_before_conv [Mamba2]
repeat_kv_before_conv: bool = Field(
default=True,
desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.",
hint=FieldHint.architecture,
)
# chunk_size [DiscreteMamba2]
chunk_size: int = Field(
default=256,
desc="Chunk size for Mamba2 blocks.",
hint=FieldHint.architecture,
)

# Learning rate
# lr_scale [MambaLayer, Mamba2, DiscreteMamba2]
mamba_lr_scale: float | None = Field(
default=None,
desc="Learning rate scale for Mamba blocks.",
hint=FieldHint.feature,
valid=skip_valid_if_none(check_field(Assert.geq, 0)),
)

# Mamba 2
repeat_kv_before_conv: bool = Field(
default=True,
desc="Whether to repeat the KV before the conv1d in Mamba2 blocks.",
hint=FieldHint.architecture,
# Initialization
# dt_weight_initialization_method [Mamba2]
dt_init: DTInitType = Field(
default=DTInitType.random,
desc="Initialization method for dt",
hint=FieldHint.core,
)
d_xb: int = Field(
default=None,
desc="Dimension of the xB in Mamba2 blocks.",
hint=FieldHint.architecture,
# dt_weight_initialization_scale [Mamba2]
dt_scale: float = Field(
default=1.0,
desc="Scale for dt",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
dt_init: str = Field(
default="random",
desc="Initialization method for dt",
# dt_bias_initialization_min [MambaLayer, Mamba2]
dt_min: float = Field(
default=0.001,
desc="Minimum step size for discretization",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
# dt_bias_initialization_max [MambaLayer, Mamba2]
dt_max: float = Field(
default=0.1,
desc="Maximum step size for discretization",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
dt_scale: float = Field(
default=1.0,
desc="Scale for dt",
# dt_bias_initialization_floor [MambaLayer, Mamba2]
dt_init_floor: float = Field(
default=1e-4,
desc="Minimum value for initializing dt",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
Expand All @@ -172,3 +187,7 @@ def _validate(self) -> None:
self.activation_type = ActivationType.silu
super()._validate()
Assert.geq(self.dt_max, self.dt_min)

def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None:
    """Intentionally do nothing; both arguments are accepted and ignored.

    Args:
        tensor_space: Unused here.
        block_type: Unused here.
    """
    # Handled in the model.
    pass
Loading