From 82eed2b44c30c891ef2e07c2c80c4f5fcfa1e7f1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 21 Jul 2025 17:17:26 -0400 Subject: [PATCH 01/40] TP mamba --- fast_llm/layers/common/config.py | 6 +- fast_llm/layers/ssm/config.py | 214 +++++++++---- fast_llm/layers/ssm/discrete_mamba2.py | 39 ++- fast_llm/layers/ssm/llamba_block.py | 18 +- fast_llm/layers/ssm/mamba2.py | 302 +++++++----------- fast_llm/layers/ssm/mamba_layer.py | 159 ++++----- fast_llm/layers/transformer/attention.py | 3 +- fast_llm/layers/transformer/transformer.py | 27 +- fast_llm/models/custom/model.py | 4 +- fast_llm/models/gpt/model.py | 8 +- fast_llm/models/ssm/config.py | 42 +-- .../external/llamba/modeling_mtp_llamba.py | 10 +- fast_llm/models/ssm/model.py | 34 +- fast_llm/tensor.py | 8 +- setup.cfg | 2 +- tests/test_multi_stage.py | 4 +- 16 files changed, 407 insertions(+), 473 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9f32ac689..07dadbc22 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_ + from fast_llm.tensor import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, @@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> " } if self.initialization_range: mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_( - mean - self.initialization_range, mean + self.initialization_range - ) + kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) return self.module_class(**kwargs) @property diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c69ada389..f4c8067dd 100644 --- a/fast_llm/layers/ssm/config.py +++ 
b/fast_llm/layers/ssm/config.py @@ -1,28 +1,35 @@ import enum from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace +from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div class SSMDimNames: - model_dim = "model_dim" # Model dimension (D) - state_dim = "state_dim" # State dimension (N) - conv_dim = "conv_dim" # Dimension of the conv1d input in mamba layers - inner_dim = "inner_dim" # Inner dimension after expansion - dt_rank = "dt_rank" # Rank of Δ - inner_proj_mamba = "inner_proj_mamba" # Inner projection dimension for mamba - inner_proj_discrete_mamba2 = "inner_proj_discrete_mamba2" # Inner projection dimension for discrete mamba2 - inner_proj_mamba2 = "inner_proj_mamba2" # Inner projection dimension for mamba2 - x_proj_dim = "x_proj_dim" # X projection dimension - head_dim = "head_dim" # Dimension of the mamba2 head (P) - conv_kernel_size = "conv_kernel_size" # Kernel size of the conv1d in mamba layers - qk_heads = "qk_heads" # Number of QK heads - v_heads = "v_heads" # Number of V heads + # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. + state = "ssm_state" # State dimension (N), aka head size / num channels + + head_groups = "ssm_head_groups" + group_heads = "ssm_group_heads" + + composite_heads = "ssm_composite_heads" + composite_heads_and_state = "ssm_composite_heads_and_state" + composite_head_groups_and_state = "ssm_composite_head_groups_and_state" + + # Inner projection total dimension. + inner_projection = "ssm_inner_projection" + composite_inner_projection = "ssm_composite_inner_projection" + + # Convolution shape in discrete mamba 2. 
TODO: Remove (dim too complex) + conv_dim = "ssm_conv_dim" + + dt_rank = "ssm_dt_rank" - # Mamba 2 - x_proj_dim_2 = "x_proj_dim" # d_xb + x_proj_dim = "x_proj_dim" # X projection dimension + conv_kernel = "conv_kernel" # Kernel size of the conv1d in mamba layers class SSMBlockType(enum.StrEnum): @@ -36,6 +43,16 @@ class SSMBlockType(enum.StrEnum): transformer = "t" +class DTInitType(enum.StrEnum): + constant = "constant" + random = "random" + + def get_init_method(self, scale: float): + from fast_llm.tensor import init_fill_, init_uniform_centered_ + + return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) + + @config_class() class SSMConfig(LLMBlockConfig): _abstract = False @@ -45,79 +62,87 @@ class SSMConfig(LLMBlockConfig): desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) + + # Model dimensions + # TODO: Remove (redundant default) expansion_factor: int = Field( default=2, desc="Expansion factor for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # head_size [MambaLayer, Mamba2, DiscreteMamba2] state_size: int = Field( default=16, desc="State size for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # [MambaLayer, Mamba2, DiscreteMamba2] conv_kernel_dimension: int = Field( default=4, desc="Conv kernel dimension for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # Layer parameters - add_bias_linear: bool = Field( - default=False, - desc="Whether to use bias in SSM layers", - hint=FieldHint.architecture, - ) - + # [MambaLayer, Mamba2] dt_rank: None | int = Field( default=None, desc="Rank of the Δ projection matrix. 
If 'None', will be set to ceil(hidden_size/16)", hint=FieldHint.architecture, ) - chunk_size: int = Field( - default=256, - desc="Chunk size for Mamba2 blocks.", - hint=FieldHint.architecture, - ) + # head_groups [DiscreteMamba2] n_qk_heads: int = Field( default=32, desc="Number of QK heads for Mamba2 blocks.", hint=FieldHint.architecture, ) + # heads [DiscreteMamba2]# TODO: Remove? (redundant) n_v_heads: int = Field( default=32, desc="Number of V heads for Mamba2 blocks.", hint=FieldHint.architecture, ) - activation_type: ActivationType = Field( + # c_size [MambaLayer, Mamba2, DiscreteMamba2]? + d_inner: None | int = Field( + default=None, + desc="Inner dimension for Mamba2 blocks.", + hint=FieldHint.core, + ) + # xb_size [Mamba2] + d_xb: int = Field( default=None, - desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", + desc="Dimension of the xB in Mamba2 blocks.", hint=FieldHint.architecture, ) - debug_ssm: bool = Field( + + # Model options + # add_bias_linear [Mamba2, DiscreteMamba2] [hard-coded to False in MambaLayer] + add_bias_linear: bool = Field( default=False, - desc="debug_ssm", - hint=FieldHint.optional, + desc="Whether to use bias in SSM layers", + hint=FieldHint.architecture, ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # activation_type [DiscreteMamba2] [hard-coded to silu in MambaLayer, Mamba2] + activation_type: ActivationType = Field( + default=None, + hint=FieldHint.architecture, ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # repeat_xb_before_conv [Mamba2] + repeat_kv_before_conv: bool = Field( + default=True, + desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.", + hint=FieldHint.architecture, ) - - d_inner: None | int = Field( - default=None, - desc="Inner 
dimension for Mamba2 blocks.", - hint=FieldHint.core, + # chunk_size [DiscreteMamba2] + chunk_size: int = Field( + default=256, + desc="Chunk size for Mamba2 blocks.", + hint=FieldHint.architecture, ) + + # Learning rate + # lr_scale [MambaLayer, Mamba2, DiscreteMamba2] mamba_lr_scale: float | None = Field( default=None, desc="Learning rate scale for Mamba blocks.", @@ -125,43 +150,38 @@ class SSMConfig(LLMBlockConfig): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # Mamba 2 - repeat_kv_before_conv: bool = Field( - default=True, - desc="Whether to repeat the KV before the conv1d in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - d_xb: int = Field( - default=None, - desc="Dimension of the xB in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - dt_init: str = Field( + # Initialization + # dt_weight_initialization_method [Mamba2] + dt_init: DTInitType = Field( default="random", desc="Initialization method for dt", hint=FieldHint.core, ) - dt_max: float = Field( - default=0.1, - desc="Maximum step size for discretization", + # dt_weight_initialization_scale [Mamba2] + dt_scale: float = Field( + default=1.0, + desc="Scale for dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) + # dt_bias_initialization_min [MambaLayer, Mamba2] dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", + # dt_bias_initialization_max [MambaLayer, Mamba2] + dt_max: float = Field( + default=0.1, + desc="Maximum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_scale: float = Field( - default=1.0, - desc="Scale for dt", + # dt_bias_initialization_floor [MambaLayer, Mamba2] + dt_init_floor: float = Field( + default=1e-4, + desc="Minimum value for initializing dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) @@ -172,3 +192,59 @@ def 
_validate(self) -> None: self.activation_type = ActivationType.silu super()._validate() Assert.geq(self.dt_max, self.dt_min) + + def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + num_heads = div(self.d_inner, self.state_size) + # Head groups are configured differently depending on the block type. + if block_type == SSMBlockType.mamba: + num_head_groups = num_heads + # (head_groups, 2 * group_heads * state_dim) + inner_projection_size = self.d_inner * 2 + elif block_type == SSMBlockType.mamba2: + num_head_groups = div(self.d_xb, self.state_size) + # (head_groups, 2 * group_heads + 2, state_dim) + (dt,) + inner_projection_size: int = 2 * self.d_inner + 2 * num_head_groups * self.state_size + self.dt_rank + elif block_type == SSMBlockType.mamba2_discrete: + Assert.eq(num_heads, self.n_v_heads) + num_head_groups = self.n_qk_heads + # (head_groups, (2 * group_heads + 2) * state_dim + group_heads) + inner_projection_size = 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads + else: + raise NotImplementedError(block_type) + + tensor_space.add_tensor_dim(state_dim := TensorDim(SSMDimNames.state, self.state_size)) + tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) + tensor_space.add_tensor_dim( + group_heads := TensorDim(SSMDimNames.group_heads, num_group_heads := div(num_heads, num_head_groups)) + ) + tensor_space.add_tensor_dim(CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads))) + tensor_space.add_tensor_dim( + CompositeTensorDim(SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim)) + ) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel, self.conv_kernel_dimension)) + + # DT 
projection + if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.dt_rank)) + + if block_type == SSMBlockType.mamba: + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.dt_rank + self.state_size * 2)) + inner_projection_size = 2 * num_group_heads * self.state_size + elif block_type == SSMBlockType.mamba2: + inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + elif block_type == SSMBlockType.mamba2_discrete: + inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + num_group_heads + # TODO: (head_groups, group_heads + 2, state_size) + tensor_space.add_tensor_dim( + TensorDim(SSMDimNames.conv_dim, self.d_inner + 2 * self.n_qk_heads * self.state_size) + ) + + tensor_space.add_tensor_dim(inner_projection := TensorDim(SSMDimNames.inner_projection, inner_projection_size)) + tensor_space.add_tensor_dim( + CompositeTensorDim(SSMDimNames.composite_inner_projection, (head_groups, inner_projection)) + ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 934cd2b5d..d06b47965 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,5 +1,6 @@ import logging import math +import typing import einops import torch @@ -7,8 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ def 
bias_init_method(conv_weight): fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) + return init_uniform_centered_(bound) class DiscreteMamba2(torch.nn.Module): @@ -53,21 +54,20 @@ def __init__( # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} super().__init__() self.config: SSMConfig = config - bias = config.add_bias_linear self.layer_idx = layer_idx self._return_input = return_input layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) + td_inner = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + td_state = tensor_space.get_tensor_dim(SSMDimNames.state) + td_model = tensor_space.get_tensor_dim(TransformerDimNames.hidden) td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_discrete_mamba2) + td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.head_groups) + td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel) + td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection) self.d_model = td_model.size self.d_inner = td_inner.size @@ -85,8 +85,8 @@ def __init__( self.in_proj = Linear( td_model, td_inner_proj, - bias=bias, - 
weight_init_method=kaiming_init_(td_model.size), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.z_bias = ( @@ -96,15 +96,13 @@ def __init__( init_method=init_zeros_, lr_scale=mamba_layer_lr_scale, ) - if not bias + if not config.add_bias_linear else 0.0 ) self.conv1d_weight = ParameterMeta.from_dims( (td_conv, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 + init_method=init_uniform_centered_((td_conv.size * td_conv_kernel.size) ** -0.5), lr_scale=mamba_layer_lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( @@ -123,12 +121,12 @@ def __init__( self.out_proj = Linear( td_inner, td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: """ ON variable names and pep8: keeping some variable names as in the original code for clarity. 
@@ -144,7 +142,6 @@ def forward(self, hidden_states, kwargs): raise NotImplementedError(f"Sequence-first not supported for SSMs.") assert _mamba_available - input_ = hidden_states outputs = {} # assert state is None batch, seqlen, dim = input_.shape diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index ee222d6d2..e877ff9c2 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.layers.transformer.transformer import BaseBlock, Mixer if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -14,21 +14,19 @@ class LlambaBlock(BaseBlock): """ _name = "Llamba block" - _mixer_module_name = "mixer" def __init__( self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", + transformer_config: "TransformerConfig", + ssm_config: "SSMConfig", tensor_space: "TensorSpace", - mixer_cls, + mixer_cls: type[Mixer], layer_index: int, return_input: bool = False, ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) + super().__init__(transformer_config, tensor_space, layer_index, return_input) + self.mixer = mixer_cls(ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) + def get_mixer(self) -> Mixer: + return self.mixer diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509abb..011889d04 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,14 +1,15 @@ -import math -import typing - -import einops import torch from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from 
fast_llm.layers.common.linear import Linear +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ -from fast_llm.utils import get_lr_scale +from fast_llm.layers.ssm.discrete_mamba2 import bias_init_method +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.utils import Assert, div, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -25,25 +26,7 @@ _causal_conv1d_available = False -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) - - -class Mamba2(torch.nn.Module): +class Mamba2(Mixer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ @@ -53,207 +36,138 @@ def __init__( config: SSMConfig, layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, ): super().__init__() - self.config: SSMConfig = config - bias: bool = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input + self._config: SSMConfig = config + Assert.eq(self._config.activation_type, ActivationType.silu) layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale( - self.config.mamba_lr_scale, layer_lr_scale - ) - - td_inner: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_dim) - td_state: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.state_dim) - td_model: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.model_dim) - tdt_rank: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - td_xb: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.x_proj_dim_2) - td_inner_proj: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_proj_mamba2) - td_conv_kernel: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel_size) + lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - self.repeat_kv_before_conv = config.repeat_kv_before_conv + inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) + hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) - self.d_state = td_state.size - self.d_model = td_model.size - self.d_xb = td_xb.size - self.d_inner = td_inner.size - self.dt_rank = tdt_rank.size - - if self.repeat_kv_before_conv: - self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - lr_scale=mamba_layer_lr_scale, - ) + self._head_groups = div(self._config.d_xb, self._config.state_size) + self._heads = div(self._config.d_inner, self._config.state_size) + self._group_heads = div(self._heads, self._head_groups) - self.conv1d_bias = ParameterMeta.from_dims( - (td_inner,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) - else: - self.conv1d_weight = ParameterMeta.from_dims( - 
(td_xb, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - ), - ) - self.conv1d_bias = ParameterMeta.from_dims( - (td_xb,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) - - self.activation = "silu" - - self.num_xb_head = td_xb.size // td_state.size - self.num_C_head = td_inner.size // td_state.size - self.repeat_group = self.num_C_head // self.num_xb_head - - self.in_proj = Linear( - td_model, - td_inner_proj, - bias=bias, - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, + conv1d_dim = ( + inner_dim + if self._config.repeat_kv_before_conv + else tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) ) - - # Initialize special dt projection to preserve variance at initialization - dt_scale = config.dt_scale # 1.0 - dt_init_std = self.dt_rank**-0.5 * dt_scale - if config.dt_init == "constant": - dt_init = init_fill_(dt_init_std) - elif config.dt_init == "random": - dt_init = init_uniform_(-dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt_max = config.dt_max # or 0.1 - dt_min = config.dt_min # or 0.001 - dt_init_floor = config.dt_init_floor # or 1e-4 - dt = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor + self.conv1d_weight = ParameterMeta.from_dims( + (conv1d_dim, tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel)), + init_method=init_uniform_centered_((conv1d_dim.size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, + ) + self.conv1d_bias = ParameterMeta.from_dims( + (conv1d_dim,), init_method=bias_init_method(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale + ) + self.in_proj = OutputParallelLinear( + hidden_dim, + 
tensor_space.get_tensor_dim(name=SSMDimNames.composite_inner_projection), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - def init_from_tensor_( - value: torch.Tensor, - ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.copy_(value) - - return init_ - self.dt_proj = Linear( - tdt_rank, - td_inner, + tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank), + inner_dim, bias=False, - weight_init_method=dt_init, - lr_scale=mamba_layer_lr_scale, + # Initialize special dt projection to preserve variance at initialization + weight_init_method=self._config.dt_init.get_init_method( + self._config.dt_rank**-0.5 * self._config.dt_scale + ), + lr_scale=lr_scale, ) # define bias outside the linear layer since its also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) - - A = einops.repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A).flatten() # Keep A_log in fp32 self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), - init_method=init_from_tensor_(A_log), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(name=SSMDimNames.state)), + init_method=init_A(self._config.state_size, self._config.d_inner), + lr_scale=lr_scale, weight_decay=False, ) - self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, 
) - - self.out_proj = Linear( - td_inner, - td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(self._config.d_inner), ) def forward(self, hidden_states, kwargs): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ assert _mamba_available - batch, seqlen, dim = hidden_states.shape - outputs = {} - - conv_state, ssm_state = None, None - - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) - - x = einops.rearrange(x, "b l d -> b d l") - z = einops.rearrange(z, "b l d -> b d l") - - B = einops.rearrange(B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state) - B = repeat_kv(B, self.repeat_group) # B, n_group, L, H - B = einops.rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() - C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - - dt = self.dt_proj(dt) + self.dt_proj_bias # B, L, d_inner - dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L + assert _causal_conv1d_available + + inner_projection = self.in_proj(hidden_states) + # Standardize to (batch, sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) + sequence_length = hidden_states.size(1) + + z, x, b, c, dt = torch.split( + inner_projection, + [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner, self._config.dt_rank], + dim=2, + ) + # z: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + z = z.transpose(1, 2) + + # x: (batch, sequence, head_groups * state) -> (batch, heads * state, sequence) + x = x.transpose(1, 2) + if self._config.repeat_kv_before_conv: + x = ( 
+ x.unflatten(1, (self._head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._heads) + .flatten(1, 2) + ) + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + else: + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + x = ( + x.unflatten(1, (self._head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._heads) + .flatten(1, 2) + ) - if self.repeat_kv_before_conv: - assert self.repeat_group > 0 - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # b: (batch, sequence, head_groups * state) -> (batch, heads, state, sequence) + b = ( + b.transpose(1, 2) + .unflatten(1, (self._head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._heads) + ) - assert self.activation in ["silu", "swish"] - if _causal_conv1d_available: - x = _causal_conv1d_fn( - x=x, - weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), - bias=self.conv1d_bias, - activation=self.activation, - ) # B, L, D - else: - raise RuntimeError("Causal conv1d is not available. 
Please install causal_conv1d.") + # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) + c = c.transpose(1, 2).unflatten(1, (self._heads, self._config.state_size)) - if not self.repeat_kv_before_conv: - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # dt: (batch, sequence, dt_rank) -> (batch, heads * state, sequence) + dt = (self.dt_proj(dt) + self.dt_proj_bias).transpose(1, 2) y = selective_scan_fn( x, dt, - A, - B, - C, + -torch.exp(self.A_log.float()), + b, + c, self.D.float(), - z=z, - delta_bias=self.dt_proj_bias.float(), # self.dt_proj.bias.float(), + z, + delta_bias=self.dt_proj_bias.float(), delta_softplus=True, - return_last_state=False, ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) - - y = einops.rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - return outputs["hidden_states"], None + # y: (batch, heads * state, sequence) -> out: (batch, sequence, hidden) + out = self.out_proj(y.transpose(1, 2))[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + out = out.transpose(0, 1) + # TODO: Is contiguous needed? 
+ return out.contiguous(), None diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d235..fa2789b1e 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,14 +1,18 @@ +import logging import math +import typing from typing import Callable -import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ -from fast_llm.utils import get_lr_scale +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.utils import Assert, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -17,6 +21,8 @@ except (ImportError, RuntimeError): _mamba_available = False +logger = logging.getLogger(__name__) + """ Note: this is mostly adapted from https://github.com/Zyphra/Zamba2, similar code is also in https://github.com/state-spaces/mamba. For now it only supports training and not inference. @@ -26,169 +32,126 @@ def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - # S4D real initialization # TODO: adopt this initialization to work for tensor parallel setting! 
- A = einops.repeat(torch.arange(1, d_state + 1, dtype=torch.float32), "n -> d n", d=d_inner).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - if tensor.shape != A_log.shape: - if tensor.numel() == A_log.numel(): - tensor_view = tensor.view(d_inner, d_state) - tensor_view.copy_(A_log) - else: - raise ValueError(f"Tensor size {tensor.numel()} doesn't match expected size {A_log.numel()}") - else: - tensor.copy_(A_log) - return tensor + if tensor.numel() != d_state * d_inner: + raise ValueError(f"_init_A requires not supported for tensor slices.") + return torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner), out=tensor) return init_ def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict + dt_max: float, dt_min: float, dt_init_floor: float ) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - ).clamp(min=dt_init_floor) + tensor = tensor.uniform_(math.log(dt_min), math.log(dt_max)).exp_().clamp_min(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - tensor.copy_(inv_dt) - return tensor + return tensor.add_(torch.log(-torch.expm1(-tensor))) return init_ -class MambaLayer(torch.nn.Module): +class MambaLayer(Mixer): def __init__( self, config: SSMConfig, layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, ): - factory_kwargs = {} super().__init__() - self.config: SSMConfig = config - self.layer_idx = layer_idx - - self._debug_mode = config.debug_ssm + assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" + self._config = config + # TODO: It's not silu? 
+ Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_inner_proj = tensor_space.get_tensor_dim( - SSMDimNames.inner_proj_mamba - ) # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) - td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - self.d_conv = td_conv_kernel.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.d_model = td_model.size - self.dt_rank = tdt_rank.size + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - self.in_proj_weight = ParameterMeta.from_dims( - (td_inner_proj, td_model), - init_method=kaiming_init_(td_model.size), + # TODO: Backward compatibility? + # TODO: lr_scale? 
+ self.in_proj = Linear( + hidden_dim, + tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection), + bias=False, + weight_init_method=init_kaiming_(hidden_dim.size), ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), - init_method=kaiming_init_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.conv_kernel)), + init_method=init_kaiming_(inner_dim.size), + lr_scale=lr_scale, ) - self.conv1d_bias = None - - self.activation = "silu" - self.act = torch.nn.SiLU() - self.x_proj = Linear( - td_inner, - td_x_proj, - weight_init_method=kaiming_init_(td_inner.size), + inner_dim, + tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim), + weight_init_method=init_kaiming_(inner_dim.size), bias=False, - lr_scale=mamba_layer_lr_scale, - **factory_kwargs, + lr_scale=lr_scale, ) self.x_proj.weight.auto_grad_accumulation = True # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (td_inner, tdt_rank), - init_method=kaiming_init_(tdt_rank.size), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.dt_rank)), + init_method=init_kaiming_(self._config.dt_rank), + lr_scale=lr_scale, ) self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), - init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs - ), - lr_scale=mamba_layer_lr_scale, + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.state)), weight_decay=False, - init_method=init_A(self.d_state, self.d_inner), - 
lr_scale=mamba_layer_lr_scale, + init_method=init_A(self._config.state_size, inner_dim.size), + lr_scale=lr_scale, ) # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) self.out_proj = Linear( - td_inner, - td_model, + inner_dim, + hidden_dim, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, - **factory_kwargs, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) self.out_proj.weight.auto_grad_accumulation = True - self._return_input = return_input - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - batch, seqlen, dim = hidden_states.shape - - # We do matmul and transpose BLH -> HBL at the same time - xz = einops.rearrange( - self.in_proj_weight @ einops.rearrange(hidden_states, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) - if self._debug_mode: - print("XZ: ", xz.shape) + in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[TransformerKwargs.sequence_first] else (0, 2, 1)) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( - xz, - self.conv1d_weight, - self.conv1d_bias, + in_proj, + self.conv1d_weight.unsqueeze(1), + None, self.x_proj.weight, self.dt_proj_weight, self.out_proj.weight, self.out_proj.bias, # is None here - A, + -torch.exp(self.A_log.float()), None, # 
input-dependent B None, # input-dependent C self.D.float(), delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if self._return_input: - out = torch.stack((hidden_states, out), dim=0) + if kwargs[TransformerKwargs.sequence_first]: + out = out.transpose(0, 1) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c9906..76b8ed1ca 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -13,6 +13,7 @@ TransformerKwargs, TransformerSubLayerName, ) +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert, get_lr_scale @@ -50,7 +51,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(Mixer): """ A self-attention layer. """ diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 147452073..f80e903f0 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -18,13 +18,24 @@ logger = logging.getLogger(__name__) +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. + """ + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. + """ + + class BaseBlock(Layer, abc.ABC): """ A transformer-like decoder base block with abstract mixer. 
""" - _mixer_module_name = "self_attn" - def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): @@ -54,7 +65,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self): + def get_mixer(self) -> Mixer: pass @torch.compile @@ -115,7 +126,7 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) + hidden_states, bias = self.get_mixer()(hidden_states, kwargs) if self._debug_mode: self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) with set_generator(generator): @@ -137,14 +148,14 @@ def forward( return hidden_states -class TransformerLayer(BaseBlock): +class TransformerBlock(BaseBlock): _name = "Transformer layer" - _mixer_module_name = "self_attn" def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__(config, tensor_space, layer_index, return_input) - - def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + + def get_mixer(self) -> Mixer: + return self.self_attn diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef406..a9cf3bb8c 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig from fast_llm.models.custom.head 
import CustomHead from fast_llm.models.gpt.config import GPTBaseModelConfig @@ -31,7 +31,7 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=i + 1, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 444ad72b2..a3a68e0a6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,7 +21,7 @@ TransformerLossNames, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -68,7 +68,7 @@ def get_output_layers(self) -> list[Layer]: for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, # TODO MTP: which index? 
@@ -91,7 +91,7 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=i + 1, @@ -336,7 +336,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[0] @property - def transformer_layers(self) -> list[TransformerLayer]: + def transformer_layers(self) -> list[TransformerBlock]: return self.layers[1:-1] @property diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11be..c294fe528 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -6,12 +6,11 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -24,7 +23,7 @@ @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False ssm: SSMConfig = Field( @@ -51,38 +50,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: Some of these can be setup directly in the layer config, but keeping them here for 
clarity. """ super().setup_tensor_space(tensor_space) - d_inner: int = self.ssm.d_inner - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) - # Mamba-specific dimensions - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_dim, d_inner)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.state_dim, self.ssm.state_size)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.ssm.dt_rank)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.ssm.dt_rank + self.ssm.state_size * 2)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - - if SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout: - # Mamba2 specific dimensions - # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 - headdim = d_inner // self.ssm.n_v_heads - Assert.eq(self.ssm.n_v_heads, d_inner // headdim) - Assert.eq(d_inner % headdim, 0) - Assert.eq(self.ssm.n_v_heads % self.ssm.n_qk_heads, 0) - - conv_dim = d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size - inner_proj_dim = 2 * d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size + self.ssm.n_v_heads - - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.head_dim, headdim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.qk_heads, self.ssm.n_qk_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.v_heads, self.ssm.n_v_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) - elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank - 
tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) + self.ssm.setup_tensor_space(tensor_space) def _validate(self): with self._set_implicit_default(None): diff --git a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py index 6d9746db1..8f49ded40 100644 --- a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py +++ b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py @@ -322,19 +322,21 @@ def __init__(self, config, factory_kwargs, layer_idx, **kwargs): # Mixer self.mixer = DiscreteMamba2( - d_model=self.config.d_model, + d_model=self.config._hidden_size, layer_idx=layer_idx, **config.ssm_cfg, **factory_kwargs, ) # Other components - self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) + self.input_layernorm = LlamaRMSNorm( + hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs + ) self.post_attention_layernorm = LlamaRMSNorm( - hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs + hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs ) self.mlp = LlamaMLP( - hidden_size=self.config.d_model, + hidden_size=self.config._hidden_size, **config.mlp_cfg, factory_kwargs=factory_kwargs, ) diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 02a5ac239..3e57689b6 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -9,7 +9,7 @@ from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba2 import Mamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from 
fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -39,14 +39,14 @@ def get_output_layers(self) -> list[Layer]: Get the output layers of the model. This includes the language model head and any additional heads specified in the configuration. """ - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=len(self._config.hybrid_block_layout), @@ -55,8 +55,8 @@ def get_output_layers(self) -> list[Layer]: ) elif block_type == SSMBlockType.mamba2_discrete: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=DiscreteMamba2, layer_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, @@ -65,8 +65,8 @@ def get_output_layers(self) -> list[Layer]: layers.append(mamba_block) elif block_type == SSMBlockType.mamba: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=MambaLayer, layer_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, @@ -75,8 +75,8 @@ def get_output_layers(self) -> list[Layer]: layers.append(mamba_block) elif block_type == SSMBlockType.mamba2: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + 
ssm_config=self._config.ssm, mixer_cls=Mamba2, layer_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, @@ -94,14 +94,14 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. """ - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): if block_type == SSMBlockType.transformer: # Transformer block layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=i + 1, @@ -112,8 +112,8 @@ def get_layers(self) -> list[Layer]: ) elif block_type == SSMBlockType.mamba2_discrete: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=DiscreteMamba2, layer_index=i + 1, tensor_space=self._tensor_space, @@ -126,8 +126,8 @@ def get_layers(self) -> list[Layer]: elif block_type == SSMBlockType.mamba: # Create Mamba block mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=MambaLayer, layer_index=i + 1, tensor_space=self._tensor_space, @@ -139,8 +139,8 @@ def get_layers(self) -> list[Layer]: elif block_type == SSMBlockType.mamba2: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=Mamba2, layer_index=i + 1, tensor_space=self._tensor_space, diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d780e4d6d..b474fe87f 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ 
-354,7 +354,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return init_ -def kaiming_init_(d_in): +def init_kaiming_(d_in): return init_normal_(0.0, math.sqrt(2.0 / d_in)) @@ -369,3 +369,9 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return tensor return init_ + + +def init_uniform_centered_( + high, max_val=None, mean=0.0 +) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: + return init_uniform_(mean - high, mean + high, min_val=mean - max_val, max_val=mean + max_val) diff --git a/setup.cfg b/setup.cfg index 2f69b8e06..c086af7d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,7 +48,7 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. for IDE support): -# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation +# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c530a170c..e5fbc7d69 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -4,7 +4,7 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -39,7 +39,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0 + 
sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, LlambaBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( From 4e310c74634a70c4d8117cc025f18a040ffbd098 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Jul 2025 13:04:54 -0400 Subject: [PATCH 02/40] TP mamba --- fast_llm/engine/config_utils/tensor_space.py | 174 ++++++++++++------- fast_llm/layers/common/linear.py | 8 +- fast_llm/layers/common/normalization.py | 4 +- fast_llm/layers/common/peft.py | 4 +- fast_llm/layers/ssm/config.py | 45 +++-- fast_llm/layers/ssm/discrete_mamba2.py | 2 +- fast_llm/layers/ssm/mamba2.py | 22 ++- fast_llm/layers/ssm/mamba_layer.py | 2 +- fast_llm/tensor.py | 31 ++-- 9 files changed, 184 insertions(+), 108 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 99c1bcf70..dceeb7da4 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -5,6 +5,8 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: + import torch + from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed @@ -23,7 +25,7 @@ def __repr__(self) -> str: f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," - f" parallel_dim={None if self.parallel_dim is None else self._parallel_dim}" + f" parallel_dim={self._parallel_dim}" f")" ) @@ -38,83 +40,134 @@ def name(self) -> str: def size(self) -> int: return self._size - @property - def expanded_shape(self) -> tuple[int, ...]: - return (self._size,) - - @property - def ndim(self) -> int: - return 1 - @property def global_size(self) -> int: return self._global_size @property - def global_expanded_shape(self) -> tuple[int, ...]: - return (self._size if self._parallel_dim is None else self._size * self._parallel_dim.size,) + def is_parallel(self) -> bool: 
+ return self._parallel_dim is not None and self._parallel_dim.size > 1 @property def parallel_dim(self) -> DistributedDim | None: + # TODO: Make more flexible for derived classes? return self._parallel_dim - @property - def parallel_dim_index(self) -> int | None: - return None if self._parallel_dim is None else 0 - @property def parallel_group(self) -> "ProcessGroup|None": + # TODO: Make more flexible for derived classes? return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim is not None + assert self.is_parallel return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + if self.parallel_group is not None: + from fast_llm.core.ops import gather_op + + return gather_op(tensor, self.parallel_group, dim) + else: + return tensor + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + return ( + tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] + if self.parallel_dim is not None and self.parallel_dim.size > 1 + else tensor + ) + class CompositeTensorDim(TensorDim): - def __init__(self, name: str, dims: tuple[TensorDim, ...]): - # TODO: Recursive composition?? - parallel_dims = [(i, dim.parallel_dim) for i, dim in enumerate(dims) if dim.parallel_dim] - Assert.leq(len(parallel_dims), 1) + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = None + for dim, tensor_dim in enumerate(tensor_dims): + if tensor_dim.is_parallel: + # TODO: Allow more than one parallel subdim? 
+ assert parallel_dim is None + parallel_dim = tensor_dim.parallel_dim + self._parallel_dim_index = dim super().__init__( name=name, - global_size=math.prod(dim.global_size for dim in dims), - parallel_dim=parallel_dims[0][1] if parallel_dims else None, - ) - self._dims = dims - self._parallel_dim_index = ( - sum(dim.ndim for dim in self._dims[: parallel_dims[0][0]]) - + self._dims[parallel_dims[0][0]].parallel_dim_index - if parallel_dims - else None + global_size=math.prod(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, ) + self._tensor_dims = tensor_dims - @property - def dims(self) -> tuple[TensorDim, ...]: - return self._dims + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self._parallel_dim_index is not None + dims = list(self._tensor_dims) + dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) - @property - def ndim(self) -> int: - return sum(dim.ndim for dim in self._dims) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global(tensor, dim + i) - @property - def expanded_shape(self) -> tuple[int, ...]: - return sum((dim.expanded_shape for dim in self._dims), ()) + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def global_expanded_shape(self) -> tuple[int, ...]: - return sum((dim.global_expanded_shape for dim in self._dims), ()) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): + tensor = tensor_dim.global_to_local(tensor, dim + i) + return tensor 
if expand else tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def parallel_dim_index(self) -> int | None: - return self._parallel_dim_index + +class ConcatenatedTensorDim(TensorDim): + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = tensor_dims[0].parallel_dim + for dim, tensor_dim in enumerate(tensor_dims[1:]): + # TODO: Allow more flexibility? + Assert.is_(tensor_dim.parallel_dim, parallel_dim) + + super().__init__( + name=name, + global_size=sum(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, + ) + self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim_index is not None - dims = list(self.dims) - dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim) - return CompositeTensorDim(self.name, tuple(dims)) + # TODO: Implement + raise NotImplementedError() + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + return ( + torch.concatenate( + [ + tensor_dim.local_to_global(tensor_, dim)[0] + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + if self.is_parallel and expand: + raise NotImplementedError() + return ( + torch.concatenate( + [ + tensor_dim.global_to_local(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) class DefaultDimNames: @@ -147,21 +200,22 @@ def distributed(self) -> "Distributed": assert self._is_setup return self._distributed - def add_tensor_dim(self, dim: TensorDim) -> None: - if 
isinstance(dim, CompositeTensorDim): - for dim_ in dim.dims: - Assert.incl(dim_.name, self._tensor_dims) - Assert.eq(dim_, self._tensor_dims[dim_.name]) - if dim.name in self._tensor_dims: - Assert.eq(dim, self._tensor_dims[dim.name]) + def add_tensor_dim(self, tensor_dim: TensorDim) -> None: + if tensor_dim.name in self._tensor_dims: + Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) else: - if dim.parallel_dim is not None: - assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name + if tensor_dim.parallel_dim is not None: + assert ( + tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims + ), tensor_dim.parallel_dim.name Assert.eq( - dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__, + tensor_dim.parallel_dim.__dict__, + self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, ) - self._tensor_dims[dim.name] = dim + self._tensor_dims[tensor_dim.name] = tensor_dim def get_tensor_dim(self, name: str) -> TensorDim: return self._tensor_dims[name] + + # TODO: Replace uses + __getitem__ = get_tensor_dim diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index cd19a47a5..7249ef569 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -94,8 +94,8 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None - assert out_dim.parallel_dim is None + assert not in_dim.is_parallel + assert not out_dim.is_parallel super().__init__( in_dim, out_dim, @@ -132,7 +132,7 @@ def __init__( sequence_parallel: bool = False, lr_scale: float | None | tuple[float | None, ...] 
= None, ): - assert in_dim.parallel_dim is None + assert not in_dim.is_parallel self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( @@ -176,7 +176,7 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert out_dim.parallel_dim is None + assert not out_dim.is_parallel self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 5f30beaef..bccc1d627 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -158,7 +158,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: @@ -242,7 +242,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 3a1966e51..08f3e535b 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -19,12 +19,12 @@ def lora_linear( ): layer.weight.requires_grad = False in_dim = layer._in_dim + assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: - assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." 
in_dim = TensorDim(in_dim.name, in_dim.global_size) out_dim = layer._out_dim + assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: - assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." out_dim = TensorDim(out_dim.name, out_dim.global_size) if out_channel_begin is not None or out_channel_end is not None: if out_channel_begin is None: diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index f4c8067dd..ce37a9804 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,7 +1,7 @@ import enum from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig @@ -20,8 +20,7 @@ class SSMDimNames: composite_head_groups_and_state = "ssm_composite_head_groups_and_state" # Inner projection total dimension. - inner_projection = "ssm_inner_projection" - composite_inner_projection = "ssm_composite_inner_projection" + concatenated_inner_projection = "ssm_concatenated_inner_projection" # Convolution shape in discrete mamba 2. 
TODO: Remove (dim too complex) conv_dim = "ssm_conv_dim" @@ -210,7 +209,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType Assert.eq(num_heads, self.n_v_heads) num_head_groups = self.n_qk_heads # (head_groups, (2 * group_heads + 2) * state_dim + group_heads) - inner_projection_size = 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads + 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads else: raise NotImplementedError(block_type) @@ -219,12 +218,18 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType tensor_space.add_tensor_dim( group_heads := TensorDim(SSMDimNames.group_heads, num_group_heads := div(num_heads, num_head_groups)) ) - tensor_space.add_tensor_dim(CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads))) tensor_space.add_tensor_dim( - CompositeTensorDim(SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim)) + heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim)) + heads_and_state := CompositeTensorDim( + SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim) + ) + ) + tensor_space.add_tensor_dim( + head_groups_and_state := CompositeTensorDim( + SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim) + ) ) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel, self.conv_kernel_dimension)) @@ -234,17 +239,27 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType if block_type == SSMBlockType.mamba: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.dt_rank + self.state_size * 2)) - inner_projection_size = 2 * num_group_heads * self.state_size + # TODO: Use composition instead + tensor_space.add_tensor_dim( + ConcatenatedTensorDim(SSMDimNames.concatenated_inner_projection, 
(heads_and_state, heads_and_state)) + ) elif block_type == SSMBlockType.mamba2: - inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + # TODO: Factor out state? + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state), + ) + ) elif block_type == SSMBlockType.mamba2_discrete: - inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + num_group_heads + # TODO: Factor as (head_groups, (group_heads + 2) * state_size + group_heads)? + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state, heads), + ) + ) # TODO: (head_groups, group_heads + 2, state_size) tensor_space.add_tensor_dim( TensorDim(SSMDimNames.conv_dim, self.d_inner + 2 * self.n_qk_heads * self.state_size) ) - - tensor_space.add_tensor_dim(inner_projection := TensorDim(SSMDimNames.inner_projection, inner_projection_size)) - tensor_space.add_tensor_dim( - CompositeTensorDim(SSMDimNames.composite_inner_projection, (head_groups, inner_projection)) - ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index d06b47965..988a09504 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -67,7 +67,7 @@ def __init__( td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.head_groups) td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection) + td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection) self.d_model = td_model.size self.d_inner = td_inner.size diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 011889d04..dff1356e6 
100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -45,6 +45,7 @@ def __init__( inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) + dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) self._head_groups = div(self._config.d_xb, self._config.state_size) self._heads = div(self._config.d_inner, self._config.state_size) @@ -65,13 +66,21 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space.get_tensor_dim(name=SSMDimNames.composite_inner_projection), + tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, weight_init_method=init_kaiming_(hidden_dim.size), lr_scale=lr_scale, ) - self.dt_proj = Linear( - tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank), + + self.dt_in_proj = Linear( + hidden_dim, + dt_rank_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, + ) + self.dt_proj = OutputParallelLinear( + dt_rank_dim, inner_dim, bias=False, # Initialize special dt projection to preserve variance at initialization @@ -110,16 +119,19 @@ def forward(self, hidden_states, kwargs): assert _causal_conv1d_available inner_projection = self.in_proj(hidden_states) + dt = self.dt_in_proj(hidden_states) # Standardize to (batch, sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) + dt = dt.transpose(0, 1) sequence_length = hidden_states.size(1) - z, x, b, c, dt = torch.split( + z, x, b, c = torch.split( inner_projection, - [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner, self._config.dt_rank], + [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner], dim=2, ) + # z: (batch, sequence, heads * state) -> (batch, heads * state, sequence) z = 
z.transpose(1, 2) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index fa2789b1e..0cdcb5242 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -74,7 +74,7 @@ def __init__( # TODO: lr_scale? self.in_proj = Linear( hidden_dim, - tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection), + tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection), bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b474fe87f..f312f1962 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -5,7 +5,7 @@ import torch from fast_llm.core.distributed import ReduceOp -from fast_llm.core.ops import gather_op, reduce_op +from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -166,14 +166,13 @@ def local_to_global( ) -> tuple[torch.Tensor, ...]: # Tensors are always either split or duplicated in the tensor-parallel direction. 
# TODO: Avoid hard-coded assumptions on duplication - is_first_rank = distributed.config.tensor_rank == 0 - modified = False - for i, dim in enumerate(self.dims): - if dim.parallel_group is not None: - tensor = gather_op( - tensor.unflatten(i, dim.expanded_shape), dim.parallel_group, i + dim.parallel_dim_index - ).flatten(i, i + len(dim.expanded_shape) - 1) - is_first_rank, modified = is_first_rank and dim.parallel_group.rank() == 0, True + is_first_rank, modified = distributed.config.tensor_rank == 0, False + + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global(tensor, dim) + is_first_rank &= tensor_dim.parallel_dim.rank == 0 + modified = True for distributed_dim, op in self._reductions: if distributed_dim.group is not None: @@ -187,23 +186,19 @@ def local_to_global( def global_to_local( self, tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. + # Return an expanded tensor, avoiding `flatten` which copies the data. TODO: Rework. expand: bool = False, ) -> torch.Tensor: """ Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. """ # Take a trivial slice to convert safetensor slices. 
- tensor_ = tensor[:] + tensor = tensor[:] assert not self._reductions - for i, dim in reversed(list(enumerate(self.dims))): - if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = tensor_.unflatten(i, dim.global_expanded_shape).chunk( - dim.parallel_dim.size, i + dim.parallel_dim_index - )[dim.parallel_dim.rank] - - return tensor_ if expand else tensor_.reshape(self.shape) + for dim, tensor_dim in reversed(list(enumerate(self.dims))): + tensor = tensor_dim.global_to_local(tensor, dim, expand) + return tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): From 3cc41182a71d28e02918d76cd882978ca8384f73 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Jul 2025 16:57:38 -0400 Subject: [PATCH 03/40] fix --- fast_llm/engine/config_utils/tensor_space.py | 6 +- fast_llm/layers/ssm/config.py | 24 +++-- fast_llm/layers/ssm/discrete_mamba2.py | 2 + fast_llm/layers/ssm/llamba_block.py | 10 +- fast_llm/layers/ssm/mamba_layer.py | 13 ++- fast_llm/layers/transformer/transformer.py | 20 ++-- fast_llm/models/ssm/config.py | 41 +++----- fast_llm/models/ssm/model.py | 99 +++++--------------- fast_llm/tensor.py | 7 +- tests/data/test_blending.py | 1 + tests/data/test_concatenate.py | 1 + tests/data/test_fim.py | 2 + tests/test_multi_stage.py | 6 +- tests/utils/model_configs.py | 43 +++++---- 14 files changed, 127 insertions(+), 148 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index dceeb7da4..d927f2e71 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -70,7 +70,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor else: return tensor - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": return ( 
tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] if self.parallel_dim is not None and self.parallel_dim.size > 1 @@ -108,7 +108,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): tensor = tensor_dim.global_to_local(tensor, dim + i) @@ -150,7 +150,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor else tensor ) - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() return ( diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index ce37a9804..aa011f75f 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -41,6 +41,22 @@ class SSMBlockType(enum.StrEnum): mamba2 = "m2" transformer = "t" + def get_mixer_class(self): + if self == SSMBlockType.mamba: + from fast_llm.layers.ssm.mamba_layer import MambaLayer + + return MambaLayer + elif self == SSMBlockType.mamba2: + from fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + elif self == SSMBlockType.mamba2_discrete: + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 + else: + raise NotImplementedError(self) + class DTInitType(enum.StrEnum): constant = "constant" @@ -199,17 +215,13 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType # Head groups are configured differently 
depending on the block type. if block_type == SSMBlockType.mamba: num_head_groups = num_heads - # (head_groups, 2 * group_heads * state_dim) - inner_projection_size = self.d_inner * 2 elif block_type == SSMBlockType.mamba2: num_head_groups = div(self.d_xb, self.state_size) - # (head_groups, 2 * group_heads + 2, state_dim) + (dt,) - inner_projection_size: int = 2 * self.d_inner + 2 * num_head_groups * self.state_size + self.dt_rank elif block_type == SSMBlockType.mamba2_discrete: Assert.eq(num_heads, self.n_v_heads) + # TODO: Fix (Du einsum crashes) + Assert.eq(self.n_qk_heads, self.n_v_heads) num_head_groups = self.n_qk_heads - # (head_groups, (2 * group_heads + 2) * state_dim + group_heads) - 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads else: raise NotImplementedError(block_type) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 988a09504..14fb8aaed 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -216,6 +216,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ else: y = result + print("AHNFIUWEGIUWEI", self.D.shape, x.shape) + # TODO: h different for D and x (qk_heads, v_heads) Du = torch.einsum("h,blhp->blhp", self.D, x) y = einops.rearrange(y + Du, "b l h p -> b l (h p)") diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index e877ff9c2..774ee7303 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -8,7 +8,7 @@ from fast_llm.layers.transformer.config import TransformerConfig -class LlambaBlock(BaseBlock): +class SSMBlock(BaseBlock): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ @@ -24,9 +24,9 @@ def __init__( layer_index: int, return_input: bool = False, ): - self._debug_mode = self._config_ssm.debug_ssm + self._ssm_config = ssm_config + self._mixer_cls = mixer_cls 
super().__init__(transformer_config, tensor_space, layer_index, return_input) - self.mixer = mixer_cls(ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) - def get_mixer(self) -> Mixer: - return self.mixer + def _create_mixer(self) -> Mixer: + return self._mixer_cls(self._ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 0cdcb5242..8235f4f1a 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,7 +1,6 @@ import logging import math import typing -from typing import Callable import torch @@ -30,21 +29,25 @@ """ -def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_A(d_state, d_inner) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa # TODO: adopt this initialization to work for tensor parallel setting! 
if tensor.numel() != d_state * d_inner: raise ValueError(f"_init_A requires not supported for tensor slices.") - return torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner), out=tensor) + return torch.log( + torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device).repeat(d_inner), out=tensor + ) return init_ def init_dtprojbias( dt_max: float, dt_min: float, dt_init_floor: float -) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - tensor = tensor.uniform_(math.log(dt_min), math.log(dt_max)).exp_().clamp_min(dt_init_floor) + tensor = ( + tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min(dt_init_floor) + ) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 return tensor.add_(torch.log(-torch.expm1(-tensor))) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index f80e903f0..a0611cd29 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,7 +8,6 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP @@ -36,6 +35,9 @@ class BaseBlock(Layer, abc.ABC): A transformer-like decoder base block with abstract mixer. 
""" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" + def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): @@ -54,7 +56,8 @@ def __init__( self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - self._create_mixer() + # The mixer needs to be created here for backward-compatible weight ordering. + setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index @@ -65,7 +68,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def get_mixer(self) -> Mixer: + def _create_mixer(self) -> Mixer: pass @torch.compile @@ -126,7 +129,7 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = self.get_mixer()(hidden_states, kwargs) + hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) if self._debug_mode: self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) with set_generator(generator): @@ -150,12 +153,15 @@ def forward( class TransformerBlock(BaseBlock): _name = "Transformer layer" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__(config, tensor_space, layer_index, return_input) - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) - def get_mixer(self) -> Mixer: - return self.self_attn + def _create_mixer(self) -> Mixer: + from fast_llm.layers.transformer.attention import Attention + + return Attention(self._config, self._tensor_space, 
self._layer_index) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index c294fe528..6b9e28584 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -30,7 +30,7 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, @@ -43,14 +43,16 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): use_megatron_initialization: bool = Field( default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing ) # TODO: is this needed? + # TODO: Support combination of different SSM block types. + ssm_block_type: SSMBlockType | None = Field(init=False) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ Setup the tensor space for the model. - Some of these can be setup directly in the layer config, but keeping them here for clarity. 
""" super().setup_tensor_space(tensor_space) - self.ssm.setup_tensor_space(tensor_space) + if self.ssm_block_type is not None: + self.ssm.setup_tensor_space(tensor_space, self.ssm_block_type) def _validate(self): with self._set_implicit_default(None): @@ -64,30 +66,21 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) + raise ValueError(message) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. 
Must be one of {SSMBlockType.__members__.values()} or None", - ) - super()._validate() + ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} + # TODO: Support combination of different SSM block types. + Assert.leq(len(ssm_block_types), 1) + self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): @@ -162,12 +155,6 @@ def _validate(self): logger.warning( "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." ) - if ( - self.base_model.sequence_first - or self.distributed.sequence_data_parallel > 1 - or self.distributed.sequence_tensor_parallel - ): - raise NotImplementedError(f"Sequence-first not supported for SSMs.") super()._validate() diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 3e57689b6..4a95891a7 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,10 +5,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba2 import Mamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -31,7 +28,6 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) def get_output_layers(self) -> list[Layer]: @@ -53,38 +49,17 
@@ def get_output_layers(self) -> list[Layer]: return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=Mamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. 
Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + layer_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -110,47 +85,19 @@ def get_layers(self) -> list[Layer]: ), ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba: - # Create Mamba block - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=Mamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. 
Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + layer_index=i + 1, + tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), + ) + ) # Add the output layers layers += self.get_output_layers() diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index f312f1962..1111fd044 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -369,4 +369,9 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_uniform_centered_( high, max_val=None, mean=0.0 ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - return init_uniform_(mean - high, mean + high, min_val=mean - max_val, max_val=mean + max_val) + return init_uniform_( + mean - high, + mean + high, + min_val=None if max_val is None else mean - max_val, + max_val=None if max_val is None else mean + max_val, + ) diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 438782dfe..3e6c37632 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -193,6 +193,7 @@ def test_gpt_blended_mixed(): def test_gpt_blended_mixed_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index e951cc2b1..4f36cdf89 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -39,6 +39,7 @@ def test_gpt_concatenate(): def test_gpt_concatenate_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 7472f1958..004b96289 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -58,6 +58,7 @@ def test_gpt_fim(): def test_gpt_fim_data(): + get_test_dataset() 
get_test_data_and_compare_samples( { "datasets": { @@ -81,6 +82,7 @@ def test_gpt_fim_data(): def test_gpt_fim_data_legacy(): + get_test_dataset() get_test_data_and_compare_samples( { "format": "list", diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index e5fbc7d69..2f125717e 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,9 +3,10 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import LlambaBlock +from fast_llm.layers.ssm.llamba_block import SSMBlock from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert +from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -23,6 +24,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): + get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, model_testing_config.model_type)._multi_stage model_frozen = _get_trainer_from_args( @@ -39,7 +41,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b96a8963b..b834ed4d1 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -451,16 
+451,14 @@ def _update_and_add_testing_config( ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. + # Tests hybrid Mamba, llamba converter. "llama", "llamba", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m']", - "model.base_model.ssm.state_size=8", - "model.base_model.ssm.chunk_size=32", - "model.base_model.ssm.n_qk_heads=8", - "model.base_model.ssm.n_v_heads=8", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=16", ], megatron_args=None, checkpoint_format=LLambaHuggingfaceCheckpointFormat, @@ -468,26 +466,31 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.broken, # TODO: Fix and bring back to `testing_groups` + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, compare_factor=2.0, - # SSMs don't support sequence-first configurations. - skip_tests=("sf", "sdp", "stp", "ms"), + # Micro-sequence split not supported. + skip_tests=("sdp", "ms"), ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "llamba", + # Tests hybrid discrete Mamba 2. + "llama", "hybrid_discrete_mamba2", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=8", + # TODO: Set to 16 once fixed. 
+ "model.base_model.ssm.n_qk_heads=32", + "model.base_model.ssm.n_v_heads=32", + "model.base_model.ssm.chunk_size=32", ], megatron_args=None, checkpoint_format=None, @@ -497,17 +500,23 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + # TODO: Implement + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, + # Micro-sequence split and sequence-first not supported. + skip_tests=("sf", "stp", "sdp", "ms"), ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "llamba", + # Tests hybrid Mamba 2. + "llama", "hybrid_mamba2", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=16", + "model.base_model.ssm.d_xb=256", ], megatron_args=None, checkpoint_format=None, @@ -517,8 +526,10 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + # Micro-sequence split not supported. 
+ skip_tests=("sdp", "ms"), ) From 9f7f75c72f1fff36a781773c8c772441d7fa9067 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Jul 2025 19:56:35 -0400 Subject: [PATCH 04/40] fix --- fast_llm/engine/config_utils/tensor_space.py | 6 +++++- fast_llm/layers/ssm/config.py | 2 -- fast_llm/layers/ssm/discrete_mamba2.py | 4 +--- fast_llm/layers/ssm/mamba2.py | 19 +++++++++++-------- fast_llm/layers/ssm/mamba_layer.py | 5 ++++- fast_llm/tensor.py | 6 ++++++ tests/utils/model_configs.py | 9 +++++---- 7 files changed, 32 insertions(+), 19 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index d927f2e71..2ca7e3e9a 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -21,7 +21,7 @@ def __init__(self, name: str, global_size: int | None, parallel_dim: Distributed def __repr__(self) -> str: return ( - f"TensorDim(" + f"{type(self).__name__}(" f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," @@ -134,6 +134,8 @@ def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: raise NotImplementedError() def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + import torch + return ( torch.concatenate( [ @@ -153,6 +155,8 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() + import torch + return ( torch.concatenate( [ diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index aa011f75f..7da4283ba 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -219,8 +219,6 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType num_head_groups = div(self.d_xb, self.state_size) elif block_type == 
SSMBlockType.mamba2_discrete: Assert.eq(num_heads, self.n_v_heads) - # TODO: Fix (Du einsum crashes) - Assert.eq(self.n_qk_heads, self.n_v_heads) num_head_groups = self.n_qk_heads else: raise NotImplementedError(block_type) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 14fb8aaed..102accb85 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -111,7 +111,7 @@ def __init__( # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_n_qk_heads,), + (td_n_v_heads,), weight_decay=False, init_method=init_ones_, lr_scale=mamba_layer_lr_scale, @@ -216,8 +216,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ else: y = result - print("AHNFIUWEGIUWEI", self.D.shape, x.shape) - # TODO: h different for D and x (qk_heads, v_heads) Du = torch.einsum("h,blhp->blhp", self.D, x) y = einops.rearrange(y + Du, "b l h p -> b l (h p)") diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index dff1356e6..11ab91e40 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,7 +4,6 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.discrete_mamba2 import bias_init_method from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer @@ -62,7 +61,9 @@ def __init__( lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (conv1d_dim,), init_method=bias_init_method(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale + (conv1d_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) self.in_proj = 
OutputParallelLinear( hidden_dim, @@ -124,7 +125,7 @@ def forward(self, hidden_states, kwargs): if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) - sequence_length = hidden_states.size(1) + sequence_length = inner_projection.size(1) z, x, b, c = torch.split( inner_projection, @@ -177,9 +178,11 @@ def forward(self, hidden_states, kwargs): delta_softplus=True, ) - # y: (batch, heads * state, sequence) -> out: (batch, sequence, hidden) - out = self.out_proj(y.transpose(1, 2))[:, :sequence_length] + # y: (batch, heads * state, sequence) -> (batch, sequence, heads * state) + y = y.transpose(1, 2)[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: - out = out.transpose(0, 1) - # TODO: Is contiguous needed? - return out.contiguous(), None + # TODO: Is contiguous needed? + y = y.transpose(0, 1).contiguous() + a, b = self.out_proj(y) + Assert.eq(a.shape, hidden_states.shape) + return a, b diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 8235f4f1a..49b9e45b7 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -35,7 +35,10 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) if tensor.numel() != d_state * d_inner: raise ValueError(f"_init_A requires not supported for tensor slices.") return torch.log( - torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device).repeat(d_inner), out=tensor + torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) + .unsqueeze(0) + .expand(d_inner, d_state), + out=tensor, ) return init_ diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 1111fd044..25ae49a31 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -164,6 +164,9 @@ def local_to_global( *, distributed: Distributed, ) -> tuple[torch.Tensor, ...]: + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) # Tensors are 
always either split or duplicated in the tensor-parallel direction. # TODO: Avoid hard-coded assumptions on duplication is_first_rank, modified = distributed.config.tensor_rank == 0, False @@ -195,6 +198,9 @@ def global_to_local( # Take a trivial slice to convert safetensor slices. tensor = tensor[:] assert not self._reductions + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.global_shape) for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim, expand) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b834ed4d1..47314263b 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -487,9 +487,8 @@ def _update_and_add_testing_config( "model.base_model.hybrid_block_layout=['t','m2d']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", - # TODO: Set to 16 once fixed. - "model.base_model.ssm.n_qk_heads=32", - "model.base_model.ssm.n_v_heads=32", + "model.base_model.ssm.n_qk_heads=8", + "model.base_model.ssm.n_v_heads=16", "model.base_model.ssm.chunk_size=32", ], megatron_args=None, @@ -503,6 +502,7 @@ def _update_and_add_testing_config( # TODO: Implement ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, + compare_factor=2.0, # Micro-sequence split and sequence-first not supported. skip_tests=("sf", "stp", "sdp", "ms"), ) @@ -515,7 +515,7 @@ def _update_and_add_testing_config( extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", "model.base_model.ssm.d_inner=512", - "model.base_model.ssm.state_size=16", + "model.base_model.ssm.state_size=8", "model.base_model.ssm.d_xb=256", ], megatron_args=None, @@ -528,6 +528,7 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + compare_factor=2.0, # Micro-sequence split not supported. 
skip_tests=("sdp", "ms"), ) From 4054e047d7318c2dfd6e37712f3b6b94d3beca5b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 15:22:24 -0400 Subject: [PATCH 05/40] fixes --- fast_llm/engine/config_utils/tensor_space.py | 11 ++++-- fast_llm/engine/multi_stage/stage_base.py | 2 + fast_llm/layers/ssm/mamba2.py | 41 +++++++++++--------- fast_llm/tensor.py | 2 + 4 files changed, 34 insertions(+), 22 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 2ca7e3e9a..0d971a88a 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -1,3 +1,4 @@ +import logging import math import typing @@ -10,6 +11,8 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed +logger = logging.getLogger(__name__) + class TensorDim: def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): @@ -130,8 +133,10 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - # TODO: Implement - raise NotImplementedError() + assert self.is_parallel + return ConcatenatedTensorDim( + self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) + ) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": import torch @@ -139,7 +144,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor return ( torch.concatenate( [ - tensor_dim.local_to_global(tensor_, dim)[0] + tensor_dim.local_to_global(tensor_, dim) for tensor_, tensor_dim in zip( tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), self._tensor_dims, diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 
2f18f1360..9a8ce2092 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -191,6 +191,8 @@ def initialize_weights(self) -> None: # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) + # It happens. + Assert.eq(global_param.shape, global_shape) if self._mode.on_device: parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 11ab91e40..a285711c6 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,3 +1,5 @@ +import logging + import torch from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -24,6 +26,8 @@ except (ImportError, RuntimeError): _causal_conv1d_available = False +logger = logging.getLogger(__name__) + class Mamba2(Mixer): """ @@ -43,21 +47,20 @@ def __init__( lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) + xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - self._head_groups = div(self._config.d_xb, self._config.state_size) - self._heads = div(self._config.d_inner, self._config.state_size) - self._group_heads = div(self._heads, self._head_groups) + self._local_heads = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads).size + self._local_head_groups = tensor_space.get_tensor_dim(name=SSMDimNames.head_groups).size + self._group_heads = div(self._local_heads, self._local_head_groups) + self._local_inner_size = inner_dim.size + self._local_xb_size = xb_dim.size - conv1d_dim = ( - inner_dim - if self._config.repeat_kv_before_conv - else tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) - ) + conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( (conv1d_dim, tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel)), - init_method=init_uniform_centered_((conv1d_dim.size * self._config.conv_kernel_dimension) ** -0.5), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( @@ -69,7 +72,7 @@ def __init__( hidden_dim, tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.size), + weight_init_method=init_kaiming_(hidden_dim.global_size), lr_scale=lr_scale, ) @@ -77,7 +80,7 @@ def __init__( hidden_dim, dt_rank_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.size), + 
weight_init_method=init_kaiming_(hidden_dim.global_size), lr_scale=lr_scale, ) self.dt_proj = OutputParallelLinear( @@ -129,7 +132,7 @@ def forward(self, hidden_states, kwargs): z, x, b, c = torch.split( inner_projection, - [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner], + [self._local_inner_size, self._local_xb_size, self._local_xb_size, self._local_inner_size], dim=2, ) @@ -140,28 +143,28 @@ def forward(self, hidden_states, kwargs): x = x.transpose(1, 2) if self._config.repeat_kv_before_conv: x = ( - x.unflatten(1, (self._head_groups, self._config.state_size)) - .repeat_interleave(self._group_heads, 1, output_size=self._heads) + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") else: x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") x = ( - x.unflatten(1, (self._head_groups, self._config.state_size)) - .repeat_interleave(self._group_heads, 1, output_size=self._heads) + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) # b: (batch, sequence, head_groups * state) -> (batch, heads, state, sequence) b = ( b.transpose(1, 2) - .unflatten(1, (self._head_groups, self._config.state_size)) - .repeat_interleave(self._group_heads, 1, output_size=self._heads) + .unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) ) # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) - c = c.transpose(1, 2).unflatten(1, (self._heads, self._config.state_size)) + c = c.transpose(1, 2).unflatten(1, (self._local_heads, self._config.state_size)) # dt: (batch, sequence, dt_rank) -> (batch, heads * state, 
sequence) dt = (self.dt_proj(dt) + self.dt_proj_bias).transpose(1, 2) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 25ae49a31..6995e9e94 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -184,6 +184,7 @@ def local_to_global( tensor = tensor.clone() tensor = reduce_op(tensor, distributed_dim.group, op=op) is_first_rank, modified = is_first_rank and distributed_dim.group.rank() == 0, True + Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank def global_to_local( @@ -204,6 +205,7 @@ def global_to_local( for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim, expand) + Assert.eq(tensor.shape, self.shape) return tensor @classmethod From 0014cc6b3f79138e53610dc86cb654a5eaba90a0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 18:02:43 -0400 Subject: [PATCH 06/40] fix --- fast_llm/layers/ssm/discrete_mamba2.py | 27 +++----- fast_llm/layers/ssm/llamba_block.py | 11 ++- fast_llm/layers/ssm/mamba2.py | 53 ++++++++++---- fast_llm/layers/ssm/mamba_layer.py | 11 +-- fast_llm/layers/transformer/attention.py | 69 ++++--------------- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 10 +-- fast_llm/layers/transformer/transformer.py | 63 ++++++++++++++--- fast_llm/models/custom/model.py | 2 +- fast_llm/models/gpt/model.py | 4 +- fast_llm/models/ssm/model.py | 8 +-- tests/utils/model_configs.py | 6 +- 12 files changed, 154 insertions(+), 116 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 102accb85..b95ff76da 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -8,7 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import 
TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale @@ -37,28 +38,23 @@ def bias_init_method(conv_weight): return init_uniform_centered_(bound) -class DiscreteMamba2(torch.nn.Module): +class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. - Other options are all experimental and should not need to be configured. 
- """ - # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") + logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) td_state = tensor_space.get_tensor_dim(SSMDimNames.state) @@ -223,9 +219,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - if self._return_input: - return torch.stack([input_, outputs["hidden_states"]], dim=0) - # TODO: since we do not support inference for now, we only return the hidden states for now. 
return outputs["hidden_states"], None diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index 774ee7303..986606634 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -21,12 +21,17 @@ def __init__( ssm_config: "SSMConfig", tensor_space: "TensorSpace", mixer_cls: type[Mixer], - layer_index: int, + block_index: int, return_input: bool = False, ): self._ssm_config = ssm_config self._mixer_cls = mixer_cls - super().__init__(transformer_config, tensor_space, layer_index, return_input) + super().__init__(transformer_config, tensor_space, block_index, return_input) def _create_mixer(self) -> Mixer: - return self._mixer_cls(self._ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) + return self._mixer_cls( + self._ssm_config, + tensor_space=self._tensor_space, + block_index=self._block_index, + transformer_config=self._config, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a285711c6..88fe4abc0 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,4 +1,5 @@ import logging +import typing import torch @@ -7,7 +8,7 @@ from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ from fast_llm.utils import Assert, div, get_lr_scale @@ -34,16 +35,31 @@ class Mamba2(Mixer): This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ + _mixer_name: 
typing.ClassVar[str] = "mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads_and_state, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads, + SSMDimNames.state, + TransformerDimNames.sequence_q, + ) + def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, + block_index: int, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self._config: SSMConfig = config Assert.eq(self._config.activation_type, ActivationType.silu) - layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) @@ -72,7 +88,8 @@ def __init__( hidden_dim, tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.global_size), + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -80,7 +97,7 @@ def __init__( hidden_dim, dt_rank_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.global_size), + weight_init_method=init_kaiming_(transformer_config.hidden_size), lr_scale=lr_scale, ) self.dt_proj = OutputParallelLinear( @@ -91,6 +108,7 @@ def __init__( weight_init_method=self._config.dt_init.get_init_method( self._config.dt_rank**-0.5 * self._config.dt_scale ), + sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) # define bias outside the linear layer since its also used in the selective_scan_fn @@ -116,6 
+134,8 @@ def __init__( hidden_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + # TODO: lr_scale? ) def forward(self, hidden_states, kwargs): @@ -123,11 +143,12 @@ def forward(self, hidden_states, kwargs): assert _causal_conv1d_available inner_projection = self.in_proj(hidden_states) - dt = self.dt_in_proj(hidden_states) + dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) + sequence_length = inner_projection.size(1) z, x, b, c = torch.split( @@ -166,8 +187,15 @@ def forward(self, hidden_states, kwargs): # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) c = c.transpose(1, 2).unflatten(1, (self._local_heads, self._config.state_size)) - # dt: (batch, sequence, dt_rank) -> (batch, heads * state, sequence) - dt = (self.dt_proj(dt) + self.dt_proj_bias).transpose(1, 2) + # dt: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + dt = dt.transpose(1, 2) + + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(b, "b", self._BC_DIMS, kwargs) + self._debug_log(c, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) y = selective_scan_fn( x, @@ -181,11 +209,12 @@ def forward(self, hidden_states, kwargs): delta_softplus=True, ) + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + # y: (batch, heads * state, sequence) -> (batch, sequence, heads * state) y = y.transpose(1, 2)[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: # TODO: Is contiguous needed? 
y = y.transpose(0, 1).contiguous() - a, b = self.out_proj(y) - Assert.eq(a.shape, hidden_states.shape) - return a, b + return self.out_proj(y) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 49b9e45b7..49afa910e 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -8,7 +8,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import Assert, get_lr_scale @@ -58,13 +58,16 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) class MambaLayer(Mixer): + _mixer_name: typing.ClassVar[str] = "mamba" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" self._config = config # TODO: It's not silu? @@ -73,7 +76,7 @@ def __init__( # Tensor dims: inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) # TODO: Backward compatibility? 
diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 76b8ed1ca..174e19588 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -14,9 +14,8 @@ TransformerSubLayerName, ) from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -56,6 +55,8 @@ class Attention(Mixer): A self-attention layer. """ + _mixer_name: typing.ClassVar[str] = "attn" + _QUERY_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, @@ -65,7 +66,7 @@ class Attention(Mixer): _KV_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, + TransformerDimNames.head_groups, TransformerDimNames.kv_channels, ) _CONTEXT_DIMS = ( @@ -74,19 +75,9 @@ class Attention(Mixer): TransformerDimNames.composite_dense, ) - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index, - ): - super().__init__() + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._config = config - self._tensor_space = tensor_space - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) - self._layer_index = layer_index - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -109,7 +100,7 @@ def __init__( 
hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) @@ -179,10 +170,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / self._block_index, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * self._block_index attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) @@ -201,40 +192,6 @@ def _attn_fused( .flatten(2) ) - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) - for dim_name in dim_names - ), - tensor_name=f"transformer layer {self._layer_index} attn {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_transformer, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name + " 
grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - def _query_key_value_forward( self, input_: torch.Tensor, sequence_first: bool ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: @@ -301,7 +258,7 @@ def _decide_window_size(self) -> int | None: # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 # TODO: make universal per layer config window_size = self._config.window_size - if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers: + if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: window_size = None return window_size @@ -342,7 +299,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( key, @@ -396,7 +353,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[TransformerKwargs.attention_mask_value], ) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query", self._QUERY_DIMS, kwargs) self._debug_log( key, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index a46af1387..73f83ccf5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", 
block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -59,7 +59,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index b01eb2aa5..efe0c5cc5 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -14,10 +14,10 @@ class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__() self._name = name - self._layer_index = layer_index + self._block_index = block_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -39,7 +39,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None 
lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, layer_lr_scale) @@ -69,9 +69,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index a0611cd29..d08db9a94 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -13,6 +13,7 @@ from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -22,6 +23,15 @@ class Mixer(torch.nn.Module, abc.ABC): Base class for mixer modules. """ + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + @abc.abstractmethod def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: """ @@ -29,6 +39,43 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ in case its addition can be made more efficient in `_bias_dropout_add`. 
""" + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim + for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + class BaseBlock(Layer, abc.ABC): """ @@ -39,7 +86,7 @@ class BaseBlock(Layer, abc.ABC): _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): super().__init__() self._config: TransformerConfig = config @@ -48,11 +95,11 @@ def __init__( # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
self._return_input: bool = return_input - self._layer_index = layer_index + self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) # Note, layer_lr_scale does not impact the norms - # TODO: add a seperate norm_lr_scale + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) @@ -60,7 +107,7 @@ def __init__( setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index ) # PEFT. @@ -81,7 +128,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -157,11 +204,11 @@ class TransformerBlock(BaseBlock): _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) def _create_mixer(self) -> Mixer: from fast_llm.layers.transformer.attention import Attention - return Attention(self._config, self._tensor_space, self._layer_index) + return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index a9cf3bb8c..534d813ff 100644 --- 
a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -34,7 +34,7 @@ def get_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, ) for i in range(self._config.transformer.num_layers) ], diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a3a68e0a6..4c1eab46f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -72,7 +72,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -94,7 +94,7 @@ def get_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 4a95891a7..89f0cd4aa 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -45,7 +45,7 @@ def get_output_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), return_input=i != self._config.prediction_heads - 1, ) ) @@ -55,7 +55,7 @@ def get_output_layers(self) -> list[Layer]: transformer_config=self._config.transformer, ssm_config=self._config.ssm, mixer_cls=self._config.ssm_block_type.get_mixer_class(), - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, return_input=i != self._config.prediction_heads - 1, ) @@ -79,7 +79,7 @@ def get_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), @@ -91,7 +91,7 @@ def get_layers(self) -> list[Layer]: transformer_config=self._config.transformer, ssm_config=self._config.ssm, mixer_cls=self._config.ssm_block_type.get_mixer_class(), - layer_index=i + 1, + block_index=i + 1, tensor_space=self._tensor_space, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 47314263b..4090e5a38 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -517,6 +517,7 @@ def _update_and_add_testing_config( "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", "model.base_model.ssm.d_xb=256", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" ], 
megatron_args=None, checkpoint_format=None, @@ -530,7 +531,10 @@ def _update_and_add_testing_config( }, compare_factor=2.0, # Micro-sequence split not supported. - skip_tests=("sdp", "ms"), + skip_tests=( + "sdp", + "ms", + ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) From 47ad5485454236d557570a32771c5888bbb3658e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 19:03:01 -0400 Subject: [PATCH 07/40] fixes --- Megatron-LM | 2 +- fast_llm/layers/language_model/head.py | 16 ++++++++++------ fast_llm/logging.py | 2 ++ fast_llm/tensor.py | 3 ++- tests/test_attention.py | 4 ++-- tests/utils/model_configs.py | 2 +- 6 files changed, 18 insertions(+), 11 deletions(-) diff --git a/Megatron-LM b/Megatron-LM index 511e8f5cb..75b0d9787 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28d..21bf3bbd0 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -125,12 +125,16 @@ def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, - tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - ) + if self._is_last_head: + return TensorMeta.from_tensor_space( + (DefaultDimNames.scalar,), + self._tensor_space, + tensor_name="Loss", + reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + ) + else: + return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 
e8334de6e..6d555a0bb 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -138,6 +138,8 @@ def log_tensor[ if level < 1: return tensor = tensor.detach() + if tensor.ndim == 0: + tensor = tensor[None] save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 6995e9e94..899e70005 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -205,7 +205,8 @@ def global_to_local( for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim, expand) - Assert.eq(tensor.shape, self.shape) + if not expand: + Assert.eq(tensor.shape, self.shape) return tensor @classmethod diff --git a/tests/test_attention.py b/tests/test_attention.py index 87b0d3e59..dd36b840a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -17,12 +17,12 @@ def test_decide_window_size(): # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 2 + attention._block_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 1 + attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4090e5a38..18db0d401 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -467,7 +467,7 @@ def _update_and_add_testing_config( ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + 
ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, From 6a074fa3c72bbe16c617a11cff690c543e4c5e86 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 19:50:05 -0400 Subject: [PATCH 08/40] fixes --- fast_llm/layers/ssm/config.py | 2 +- fast_llm/models/ssm/conversion.py | 18 ++++++---- tests/utils/model_configs.py | 55 ++++++++++++++++--------------- 3 files changed, 41 insertions(+), 34 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 7da4283ba..15a6a8210 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -168,7 +168,7 @@ class SSMConfig(LLMBlockConfig): # Initialization # dt_weight_initialization_method [Mamba2] dt_init: DTInitType = Field( - default="random", + default=DTInitType.random, desc="Initialization method for dt", hint=FieldHint.core, ) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d57300252..43e3c67e5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,7 @@ import pathlib import typing +from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -19,7 +20,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, @@ 
-42,11 +43,11 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = RenameParamConverter( + block_converter = MappedConfigParamConverter( fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),), - ignore_missing=True, - default_value=[cls._default_block_type], + fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, + export_value=lambda x: [x_.value for x_ in x], ) return super()._create_config_converters() + [block_converter] @@ -202,7 +203,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - RenameParamConverter( + MappedConfigParamConverter( fast_llm_names=(("ssm", "dt_init"),), export_names=( ( @@ -210,8 +211,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: "dt_init", ), ), - ignore_missing=True, - default_value="random", + fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), + export_value=lambda x: x.value, ), ] @@ -258,6 +259,9 @@ def _create_weight_converters(self) -> list[WeightConverter]: ) # ================================================ # Mamba2 specific parameters + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.dt_in_proj", f"model.layers.{i}.mixer.dt_in_proj", ssm_bias + ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 18db0d401..3ffc3281b 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -19,7 +19,10 @@ Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import ( + AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + 
LLambaHuggingfaceCheckpointFormat, +) from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from tests.utils.distributed_configs import DistributedTestingConfig @@ -467,7 +470,7 @@ def _update_and_add_testing_config( ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, @@ -477,47 +480,49 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms"), ) - _update_and_add_testing_config( - # Tests hybrid discrete Mamba 2. + # Tests hybrid Mamba 2. "llama", - "hybrid_discrete_mamba2", + "hybrid_mamba2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.hybrid_block_layout=['t','m2']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", - "model.base_model.ssm.n_qk_heads=8", - "model.base_model.ssm.n_v_heads=16", - "model.base_model.ssm.chunk_size=32", + "model.base_model.ssm.d_xb=256", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - # TODO: Implement - ModelTestingGroup.distributed: 
ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, compare_factor=2.0, - # Micro-sequence split and sequence-first not supported. - skip_tests=("sf", "stp", "sdp", "ms"), + # Micro-sequence split not supported. + skip_tests=( + "sdp", + "ms", + ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) + _update_and_add_testing_config( - # Tests hybrid Mamba 2. + # Tests hybrid discrete Mamba 2. "llama", - "hybrid_mamba2", + "hybrid_discrete_mamba2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.hybrid_block_layout=['t','m2d']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", - "model.base_model.ssm.d_xb=256", - # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" + "model.base_model.ssm.n_qk_heads=8", + "model.base_model.ssm.n_v_heads=16", + "model.base_model.ssm.chunk_size=32", ], megatron_args=None, checkpoint_format=None, @@ -527,14 +532,12 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + # TODO: Implement + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, compare_factor=2.0, - # Micro-sequence split not supported. - skip_tests=( - "sdp", - "ms", - ), # "pp","dp", "ce","16", "bf", "df", "stp"), + # Micro-sequence split and sequence-first not supported. 
+ skip_tests=("sf", "stp", "sdp", "ms"), ) From d66651f5433392794d1b45560282d9237824881d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 19:56:19 -0400 Subject: [PATCH 09/40] Update external --- .../modeling_ssm_hybrid_apriel15b.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a0520..4fde72458 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -843,9 +843,8 @@ def __init__( self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // self.num_xb_head - self.in_proj = nn.Linear( - self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=bias, **factory_kwargs - ) + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -933,8 +932,17 @@ def forward( outputs = {} A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [ + self.d_inner, + self.d_xb, + self.d_xb, + self.d_inner, + ], + dim=-1, + ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") @@ -944,7 +952,7 @@ def forward( B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", 
dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: From 50083ba88a0bfa58747d2bc8307814b62af1a79a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 15:14:13 -0400 Subject: [PATCH 10/40] SSM debugging --- Megatron-LM | 2 +- fast_llm/engine/multi_stage/stage_base.py | 2 + fast_llm/layers/language_model/head.py | 16 ++- fast_llm/layers/ssm/config.py | 34 +++--- fast_llm/layers/ssm/discrete_mamba2.py | 23 ++-- fast_llm/layers/ssm/llamba_block.py | 29 +++-- fast_llm/layers/ssm/mamba2.py | 38 ++++-- fast_llm/layers/ssm/mamba_layer.py | 36 +++--- fast_llm/layers/transformer/attention.py | 72 +++-------- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 10 +- fast_llm/layers/transformer/transformer.py | 94 ++++++++++++--- fast_llm/logging.py | 2 + fast_llm/models/gpt/model.py | 12 +- fast_llm/models/ssm/config.py | 40 +++---- fast_llm/models/ssm/model.py | 113 +++++------------- setup.cfg | 7 +- tests/data/test_blending.py | 1 + tests/data/test_concatenate.py | 1 + tests/data/test_fim.py | 2 + tests/test_attention.py | 4 +- tests/test_multi_stage.py | 8 +- tests/utils/model_configs.py | 1 + 23 files changed, 271 insertions(+), 282 deletions(-) diff --git a/Megatron-LM b/Megatron-LM index 511e8f5cb..75b0d9787 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 2f18f1360..9a8ce2092 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -191,6 +191,8 @@ def initialize_weights(self) -> None: # Initialize all global weights on every gpu, then select the appropriate slice if 
applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) + # It happens. + Assert.eq(global_param.shape, global_shape) if self._mode.on_device: parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28d..21bf3bbd0 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -125,12 +125,16 @@ def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, - tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - ) + if self._is_last_head: + return TensorMeta.from_tensor_space( + (DefaultDimNames.scalar,), + self._tensor_space, + tensor_name="Loss", + reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + ) + else: + return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 46d629aa8..a1f357de9 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -23,6 +23,7 @@ class SSMDimNames: # Mamba 2 x_proj_dim_2 = "x_proj_dim_2" # d_xb + c_heads = "c_heads" class SSMBlockType(enum.StrEnum): @@ -35,6 +36,22 @@ class SSMBlockType(enum.StrEnum): mamba2 = "m2" transformer = "t" + def get_mixer_class(self): + if self == SSMBlockType.mamba: + from fast_llm.layers.ssm.mamba_layer import MambaLayer + + return MambaLayer + elif self == SSMBlockType.mamba2: + from fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + elif 
self == SSMBlockType.mamba2_discrete: + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 + else: + raise NotImplementedError(self) + @config_class() class SSMConfig(LLMBlockConfig): @@ -95,11 +112,6 @@ class SSMConfig(LLMBlockConfig): desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", hint=FieldHint.architecture, ) - debug_ssm: bool = Field( - default=False, - desc="debug_ssm", - hint=FieldHint.optional, - ) dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", @@ -147,18 +159,6 @@ class SSMConfig(LLMBlockConfig): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) dt_scale: float = Field( default=1.0, desc="Scale for dt", diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 934cd2b5d..734e35b21 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,5 +1,6 @@ import logging import math +import typing import einops import torch @@ -7,7 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -36,29 +38,29 @@ def bias_init_method(conv_weight): return init_uniform_(-bound, bound) -class 
DiscreteMamba2(torch.nn.Module): +class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): """ See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Other options are all experimental and should not need to be configured. """ # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config bias = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") + logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) @@ -226,9 +228,6 @@ def forward(self, hidden_states, kwargs): out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - if self._return_input: - return torch.stack([input_, outputs["hidden_states"]], dim=0) - # TODO: since we do not support inference for now, we only return the hidden states for now. 
return outputs["hidden_states"], None diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index ee222d6d2..986606634 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.layers.transformer.transformer import BaseBlock, Mixer if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -8,27 +8,30 @@ from fast_llm.layers.transformer.config import TransformerConfig -class LlambaBlock(BaseBlock): +class SSMBlock(BaseBlock): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ _name = "Llamba block" - _mixer_module_name = "mixer" def __init__( self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", + transformer_config: "TransformerConfig", + ssm_config: "SSMConfig", tensor_space: "TensorSpace", - mixer_cls, - layer_index: int, + mixer_cls: type[Mixer], + block_index: int, return_input: bool = False, ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm - self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) + self._ssm_config = ssm_config + self._mixer_cls = mixer_cls + super().__init__(transformer_config, tensor_space, block_index, return_input) - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) + def _create_mixer(self) -> Mixer: + return self._mixer_cls( + self._ssm_config, + tensor_space=self._tensor_space, + block_index=self._block_index, + transformer_config=self._config, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509abb..ead32fa2a 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -7,6 +7,8 @@ from fast_llm.engine.config_utils.tensor_space import 
TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -43,24 +45,36 @@ def bias_init_method(conv_weight): return init_uniform_(-bound, bound) -class Mamba2(torch.nn.Module): +class Mamba2(Mixer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ + _mixer_name: typing.ClassVar[str] = "mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.inner_dim, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.c_heads, + SSMDimNames.state_dim, + TransformerDimNames.sequence_q, + ) + def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, + block_index: int, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config bias: bool = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale( self.config.mamba_lr_scale, layer_lr_scale ) @@ -236,6 +250,13 @@ def forward(self, hidden_states, kwargs): x = repeat_kv(x, self.repeat_group) x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(B, "b", self._BC_DIMS, kwargs) + self._debug_log(C, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + y = selective_scan_fn( x, dt, @@ -249,6 +270,9 @@ def forward(self, hidden_states, kwargs): return_last_state=False, ) + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + if ssm_state is not None: y, last_state = y ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d235..a95e94c03 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,4 +1,5 @@ import math +import typing from typing import Callable import einops @@ -7,6 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -44,12 +47,12 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict + d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float ) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # 
noqa - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - ).clamp(min=dt_init_floor) + dt = torch.exp(torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( + min=dt_init_floor + ) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) tensor.copy_(inv_dt) @@ -58,20 +61,18 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return init_ -class MambaLayer(torch.nn.Module): +class MambaLayer(Mixer): + _mixer_name: typing.ClassVar[str] = "mamba" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - factory_kwargs = {} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config - self.layer_idx = layer_idx - - self._debug_mode = config.debug_ssm # Tensor dims: td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) @@ -88,7 +89,7 @@ def __init__( self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) self.in_proj_weight = ParameterMeta.from_dims( @@ -113,7 +114,6 @@ def __init__( weight_init_method=kaiming_init_(td_inner.size), bias=False, lr_scale=mamba_layer_lr_scale, - **factory_kwargs, ) self.x_proj.weight.auto_grad_accumulation = True @@ -127,7 +127,7 @@ def __init__( self.dt_proj_bias = ParameterMeta.from_dims( (td_inner,), init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs + 
self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor ), lr_scale=mamba_layer_lr_scale, ) @@ -153,10 +153,8 @@ def __init__( bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. weight_init_method=kaiming_init_(td_model.size), lr_scale=mamba_layer_lr_scale, - **factory_kwargs, ) self.out_proj.weight.auto_grad_accumulation = True - self._return_input = return_input def forward(self, hidden_states, kwargs): assert _mamba_available @@ -168,8 +166,6 @@ def forward(self, hidden_states, kwargs): "d (b l) -> b d l", l=seqlen, ) - if self._debug_mode: - print("XZ: ", xz.shape) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat @@ -189,6 +185,4 @@ def forward(self, hidden_states, kwargs): delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if self._return_input: - out = torch.stack((hidden_states, out), dim=0) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c9906..174e19588 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -13,9 +13,9 @@ TransformerKwargs, TransformerSubLayerName, ) -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -50,11 +50,13 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(Mixer): """ A self-attention layer. 
""" + _mixer_name: typing.ClassVar[str] = "attn" + _QUERY_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, @@ -64,7 +66,7 @@ class Attention(torch.nn.Module): _KV_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, + TransformerDimNames.head_groups, TransformerDimNames.kv_channels, ) _CONTEXT_DIMS = ( @@ -73,19 +75,9 @@ class Attention(torch.nn.Module): TransformerDimNames.composite_dense, ) - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index, - ): - super().__init__() + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._config = config - self._tensor_space = tensor_space - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) - self._layer_index = layer_index - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -108,7 +100,7 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
@@ -178,10 +170,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / self._block_index, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * self._block_index attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) @@ -200,40 +192,6 @@ def _attn_fused( .flatten(2) ) - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) - for dim_name in dim_names - ), - tensor_name=f"transformer layer {self._layer_index} attn {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_transformer, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - def _query_key_value_forward( self, input_: torch.Tensor, sequence_first: bool ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: @@ -300,7 +258,7 @@ def _decide_window_size(self) -> int | None: # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 # TODO: make 
universal per layer config window_size = self._config.window_size - if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers: + if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: window_size = None return window_size @@ -341,7 +299,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( key, @@ -395,7 +353,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[TransformerKwargs.attention_mask_value], ) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query", self._QUERY_DIMS, kwargs) self._debug_log( key, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index a46af1387..73f83ccf5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -59,7 +59,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index b01eb2aa5..efe0c5cc5 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -14,10 +14,10 @@ class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__() self._name = name - self._layer_index = layer_index + self._block_index = block_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -39,7 +39,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, 
layer_lr_scale) @@ -69,9 +69,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 147452073..d08db9a94 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,25 +8,85 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. 
+ """ + + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. + """ + + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim + for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + + class BaseBlock(Layer, abc.ABC): """ A transformer-like decoder base block with abstract mixer. 
""" - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): super().__init__() self._config: TransformerConfig = config @@ -35,18 +95,19 @@ def __init__( # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._layer_index = layer_index + self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) # Note, layer_lr_scale does not impact the norms - # TODO: add a seperate norm_lr_scale + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - self._create_mixer() + # The mixer needs to be created here for backward-compatible weight ordering. + setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index ) # PEFT. 
@@ -54,7 +115,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self): + def _create_mixer(self) -> Mixer: pass @torch.compile @@ -67,7 +128,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -137,14 +198,17 @@ def forward( return hidden_states -class TransformerLayer(BaseBlock): +class TransformerBlock(BaseBlock): _name = "Transformer layer" - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) + + def _create_mixer(self) -> Mixer: + from fast_llm.layers.transformer.attention import Attention - def _create_mixer(self): - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index e8334de6e..6d555a0bb 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -138,6 +138,8 @@ def log_tensor[ if level < 1: return tensor = tensor.detach() + if tensor.ndim == 0: + tensor = tensor[None] save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 444ad72b2..4c1eab46f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,7 +21,7 @@ TransformerLossNames, ) 
from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -68,11 +68,11 @@ def get_output_layers(self) -> list[Layer]: for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -91,10 +91,10 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, @@ -336,7 +336,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[0] @property - def transformer_layers(self) -> list[TransformerLayer]: + def transformer_layers(self) -> list[TransformerBlock]: return self.layers[1:-1] @property diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11be..9ca0123b2 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -9,9 +9,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -24,14 +23,14 @@ @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False ssm: SSMConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, @@ -41,9 +40,8 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): desc="Multi-token prediction mixer to use in the model. 
If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? + # TODO: Support combination of different SSM block types. + ssm_block_type: SSMBlockType | None = Field(init=False) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ @@ -83,6 +81,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.c_heads, d_inner // self.ssm.state_size)) def _validate(self): with self._set_implicit_default(None): @@ -96,30 +95,21 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) + raise ValueError(message) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + 
logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", - ) - super()._validate() + ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} + # TODO: Support combination of different SSM block types. + Assert.leq(len(ssm_block_types), 1) + self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 02a5ac239..89f0cd4aa 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,11 +5,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba2 import Mamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, 
SSMBlockType @@ -31,7 +28,6 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) def get_output_layers(self) -> list[Layer]: @@ -39,52 +35,31 @@ def get_output_layers(self) -> list[Layer]: Get the output layers of the model. This includes the language model head and any additional heads specified in the configuration. """ - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - 
config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -94,63 +69,35 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. """ - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): if block_type == SSMBlockType.transformer: # Transformer block layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba: - # Create Mamba block - 
mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=i + 1, + tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), + ) + ) # Add the output layers layers += self.get_output_layers() diff --git a/setup.cfg b/setup.cfg index 843aa15ca..c086af7d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,14 +48,9 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. 
for IDE support): -# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation +# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 - cartesia_pytorch>=0.0.2 - -GENERATION = - lm_eval>=0.4.9 - DEV = # Pre-commit git hook diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 438782dfe..3e6c37632 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -193,6 +193,7 @@ def test_gpt_blended_mixed(): def test_gpt_blended_mixed_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index e951cc2b1..4f36cdf89 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -39,6 +39,7 @@ def test_gpt_concatenate(): def test_gpt_concatenate_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 7472f1958..004b96289 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -58,6 +58,7 @@ def test_gpt_fim(): def test_gpt_fim_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { @@ -81,6 +82,7 @@ def test_gpt_fim_data(): def test_gpt_fim_data_legacy(): + get_test_dataset() get_test_data_and_compare_samples( { "format": "list", diff --git a/tests/test_attention.py b/tests/test_attention.py index 87b0d3e59..dd36b840a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -17,12 +17,12 @@ def test_decide_window_size(): # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 2 + attention._block_index = 2 assert 
attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 1 + attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c530a170c..2f125717e 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,9 +3,10 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert +from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -23,6 +24,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): + get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, model_testing_config.model_type)._multi_stage model_frozen = _get_trainer_from_args( @@ -39,7 +41,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 for layer in model_ref.base_model.layers ] for 
weight_buffer_ref, weight_buffer_frozen in zip( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1eee3675d..42252c620 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,6 +523,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", + f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From 7b32699be7c1a1fb29cc7386eb33280b0bc19a5c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 17:28:56 -0400 Subject: [PATCH 11/40] stuff --- fast_llm/layers/ssm/mamba2.py | 57 ++++++++++++++--------------------- fast_llm/models/ssm/config.py | 2 +- tests/utils/model_configs.py | 2 +- 3 files changed, 24 insertions(+), 37 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index ead32fa2a..b936ccf14 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -7,6 +7,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ @@ -97,9 +98,9 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("1", 1), td_conv_kernel), + (td_inner, td_conv_kernel), init_method=init_uniform_( - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), + -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), ), # see 
https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 lr_scale=mamba_layer_lr_scale, @@ -110,9 +111,9 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, TensorDim("1", 1), td_conv_kernel), + (td_xb, td_conv_kernel), init_method=init_uniform_( - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), + -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), ), ) @@ -133,7 +134,13 @@ def __init__( weight_init_method=kaiming_init_(td_model.size), lr_scale=mamba_layer_lr_scale, ) - + self.dt_in_proj = Linear( + td_model, + tdt_rank, + bias=config.add_bias_linear, + weight_init_method=kaiming_init_(transformer_config.hidden_size), + lr_scale=mamba_layer_lr_scale, + ) # Initialize special dt projection to preserve variance at initialization dt_scale = config.dt_scale # 1.0 dt_init_std = self.dt_rank**-0.5 * dt_scale @@ -144,24 +151,6 @@ def __init__( else: raise NotImplementedError - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt_max = config.dt_max # or 0.1 - dt_min = config.dt_min # or 0.001 - dt_init_floor = config.dt_init_floor # or 1e-4 - dt = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor - ) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - def init_from_tensor_( - value: torch.Tensor, - ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.copy_(value) - - return init_ - self.dt_proj = Linear( tdt_rank, td_inner, @@ -171,18 +160,16 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) ) # define bias outside the linear layer since its also used in the selective_scan_fn self.dt_proj_bias = 
ParameterMeta.from_dims( - (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale + (td_inner,), + init_method=init_dtprojbias( + self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor + ), + lr_scale=mamba_layer_lr_scale, ) - A = einops.repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A).flatten() # Keep A_log in fp32 self.A_log = ParameterMeta.from_dims( (td_inner, td_state), - init_method=init_from_tensor_(A_log), + init_method=init_A(self.config.state_size, self.config.d_inner), lr_scale=mamba_layer_lr_scale, weight_decay=False, ) @@ -214,8 +201,8 @@ def forward(self, hidden_states, kwargs): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1) x = einops.rearrange(x, "b l d -> b d l") z = einops.rearrange(z, "b l d -> b d l") @@ -225,7 +212,7 @@ def forward(self, hidden_states, kwargs): B = einops.rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) + self.dt_proj_bias # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias # B, L, d_inner dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: @@ -238,7 +225,7 @@ def forward(self, hidden_states, kwargs): if _causal_conv1d_available: x = _causal_conv1d_fn( x=x, - weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + weight=self.conv1d_weight, bias=self.conv1d_bias, activation=self.activation, ) # B, L, D diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py 
index 9ca0123b2..b04b1f210 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -78,7 +78,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank + inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner # + self.ssm.dt_rank tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.c_heads, d_inner // self.ssm.state_size)) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 42252c620..4976ad2b1 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,7 +523,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", - f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From 1feccc866c1dea2da66567476fc911a37a855038 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 17:48:23 -0400 Subject: [PATCH 12/40] stuff --- fast_llm/layers/ssm/mamba2.py | 2 +- fast_llm/layers/ssm/mamba_layer.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 88fe4abc0..fdba10beb 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -111,7 +111,7 @@ def __init__( sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - # define bias outside the linear layer since its also used in the selective_scan_fn + # define 
bias outside the linear layer since it's also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( (inner_dim,), init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 49afa910e..11db37910 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -48,9 +48,7 @@ def init_dtprojbias( dt_max: float, dt_min: float, dt_init_floor: float ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - tensor = ( - tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min(dt_init_floor) - ) + tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min_(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 return tensor.add_(torch.log(-torch.expm1(-tensor))) From e528b50ba5c5e2ea726876779db010f83fccd8ef Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:00:20 -0400 Subject: [PATCH 13/40] misc --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++-- fast_llm/layers/ssm/mamba2.py | 12 ++++++++---- fast_llm/layers/ssm/mamba_layer.py | 10 +++++++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index b95ff76da..fdce9bf63 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import 
TransformerConfig, TransformerDimNames, TransformerKwargs @@ -97,7 +97,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, TensorDim("1", 1), td_conv_kernel), + (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_centered_((td_conv.size * td_conv_kernel.size) ** -0.5), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index fdba10beb..8be9dcb9b 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,7 +3,7 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames @@ -75,7 +75,11 @@ def __init__( conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( - (conv1d_dim, tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel)), + ( + conv1d_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) @@ -168,9 +172,9 @@ def forward(self, hidden_states, kwargs): .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias.squeeze(1), activation="silu") else: - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, 
bias=self.conv1d_bias.squeeze(1), activation="silu") x = ( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 11db37910..07eec38e6 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,7 +4,7 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames @@ -87,7 +87,11 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.conv_kernel)), + ( + inner_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(SSMDimNames.conv_kernel), + ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, ) @@ -146,7 +150,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( in_proj, - self.conv1d_weight.unsqueeze(1), + self.conv1d_weight, None, self.x_proj.weight, self.dt_proj_weight, From b49c42febac4f32dc1be83655b242d6199a385bc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:16:42 -0400 Subject: [PATCH 14/40] misc --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++-- fast_llm/layers/ssm/mamba2.py | 8 ++++---- fast_llm/layers/ssm/mamba_layer.py | 4 ++-- .../modeling_ssm_hybrid_apriel15b.py | 20 +++++++++++++------ tests/utils/model_configs.py | 1 - 5 files changed, 22 insertions(+), 15 deletions(-) diff --git 
a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 734e35b21..c0ae7e781 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs @@ -103,7 +103,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, TensorDim("1", 1), td_conv_kernel), + (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b936ccf14..74c212add 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,7 +4,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias @@ -98,7 +98,7 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, td_conv_kernel), + (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * 
td_conv_kernel.size), @@ -111,7 +111,7 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, td_conv_kernel), + (td_xb, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), @@ -225,7 +225,7 @@ def forward(self, hidden_states, kwargs): if _causal_conv1d_available: x = _causal_conv1d_fn( x=x, - weight=self.conv1d_weight, + weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), bias=self.conv1d_bias, activation=self.activation, ) # B, L, D diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index a95e94c03..4493332ce 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig @@ -98,7 +98,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), + (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=kaiming_init_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a0520..4fde72458 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -843,9 +843,8 @@ def __init__( self.num_C_head = self.d_inner // self.d_state 
self.repeat_group = self.num_C_head // self.num_xb_head - self.in_proj = nn.Linear( - self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=bias, **factory_kwargs - ) + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -933,8 +932,17 @@ def forward( outputs = {} A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [ + self.d_inner, + self.d_xb, + self.d_xb, + self.d_inner, + ], + dim=-1, + ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") @@ -944,7 +952,7 @@ def forward( B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4976ad2b1..1eee3675d 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,7 +523,6 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", - # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From c1b7f44a10ff379a067b10b76df296f3bee4cac1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:19:08 -0400 Subject: [PATCH 
15/40] misc --- .../models/ssm/external/llamba/modeling_mtp_llamba.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py index 8f49ded40..6d9746db1 100644 --- a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py +++ b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py @@ -322,21 +322,19 @@ def __init__(self, config, factory_kwargs, layer_idx, **kwargs): # Mixer self.mixer = DiscreteMamba2( - d_model=self.config._hidden_size, + d_model=self.config.d_model, layer_idx=layer_idx, **config.ssm_cfg, **factory_kwargs, ) # Other components - self.input_layernorm = LlamaRMSNorm( - hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs - ) + self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) self.post_attention_layernorm = LlamaRMSNorm( - hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs + hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs ) self.mlp = LlamaMLP( - hidden_size=self.config._hidden_size, + hidden_size=self.config.d_model, **config.mlp_cfg, factory_kwargs=factory_kwargs, ) From 31f5d415ef0c7eeca54a26d415076cbf3ba33cfd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:20:26 -0400 Subject: [PATCH 16/40] misc --- fast_llm/models/ssm/conversion.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d57300252..43e3c67e5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,7 @@ import pathlib import typing +from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -19,7 +20,7 @@ 
from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, @@ -42,11 +43,11 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = RenameParamConverter( + block_converter = MappedConfigParamConverter( fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),), - ignore_missing=True, - default_value=[cls._default_block_type], + fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, + export_value=lambda x: [x_.value for x_ in x], ) return super()._create_config_converters() + [block_converter] @@ -202,7 +203,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - RenameParamConverter( + MappedConfigParamConverter( fast_llm_names=(("ssm", "dt_init"),), export_names=( ( @@ -210,8 +211,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: "dt_init", ), ), - ignore_missing=True, - default_value="random", + fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), + export_value=lambda x: x.value, ), ] @@ -258,6 +259,9 @@ def _create_weight_converters(self) -> list[WeightConverter]: ) # ================================================ # Mamba2 specific parameters + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.dt_in_proj", f"model.layers.{i}.mixer.dt_in_proj", ssm_bias + ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.mixer.dt_proj", 
f"model.layers.{i}.mixer.dt_proj", False ) From 0a9ff25f6e0a699caef881dfcaeef0b19f825764 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:22:24 -0400 Subject: [PATCH 17/40] misc --- fast_llm/models/ssm/config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 6b9e28584..d2a69303c 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -40,9 +40,6 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): desc="Multi-token prediction mixer to use in the model. If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? # TODO: Support combination of different SSM block types. ssm_block_type: SSMBlockType | None = Field(init=False) From e7d9636819ab83df7204cc2b021fd4565188e946 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 19:55:53 -0400 Subject: [PATCH 18/40] Parallel discrete mamba 2 --- fast_llm/layers/ssm/config.py | 12 +- fast_llm/layers/ssm/discrete_mamba2.py | 212 ++++++++++--------------- fast_llm/layers/ssm/mamba2.py | 6 +- 3 files changed, 95 insertions(+), 135 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 15a6a8210..7f0b3cf61 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -211,23 +211,25 @@ def _validate(self) -> None: def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - num_heads = div(self.d_inner, self.state_size) # Head groups are configured differently depending on the block type. 
if block_type == SSMBlockType.mamba: + num_heads = div(self.d_inner, self.state_size) num_head_groups = num_heads elif block_type == SSMBlockType.mamba2: + num_heads = div(self.d_inner, self.state_size) num_head_groups = div(self.d_xb, self.state_size) elif block_type == SSMBlockType.mamba2_discrete: - Assert.eq(num_heads, self.n_v_heads) + # TODO: Use different variables? + num_heads = self.n_v_heads num_head_groups = self.n_qk_heads + # v_heads have size `headdim` that may be different from `state_size`. + Assert.multiple(self.d_inner, num_heads) else: raise NotImplementedError(block_type) tensor_space.add_tensor_dim(state_dim := TensorDim(SSMDimNames.state, self.state_size)) tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) - tensor_space.add_tensor_dim( - group_heads := TensorDim(SSMDimNames.group_heads, num_group_heads := div(num_heads, num_head_groups)) - ) + tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) tensor_space.add_tensor_dim( heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index fdce9bf63..ac4fb87cc 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,12 +1,12 @@ import logging -import math import typing import einops import torch from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace -from fast_llm.layers.common.linear import Linear +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer @@ -32,12 +32,6 @@ 
_causal_conv1d_available = False -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_centered_(bound) - - class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" @@ -51,198 +45,162 @@ def __init__( transformer_config: TransformerConfig, ): super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) - self.config: SSMConfig = config + self._config: SSMConfig = config layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") - - td_inner = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state) - td_model = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.head_groups) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection) - - self.d_model = td_model.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.chunk_size = config.chunk_size - self.n_qk_heads = td_n_qk_heads.size - self.n_v_heads = td_n_v_heads.size - self.conv_kernel_size = td_conv_kernel.size - - self.act = config.activation_type.activation_fn - self.activation_name = config.activation_type.name + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + inner_dim = 
tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) + conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) + heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + + self._local_heads = heads_dim.size + self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + self._local_inner_size = inner_dim.size + self._local_bc_size = tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size # TODO: double check initializations # Projections - self.in_proj = Linear( - td_model, - td_inner_proj, + self.in_proj = OutputParallelLinear( + hidden_dim, + tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, - weight_init_method=init_kaiming_(td_model.size), - lr_scale=mamba_layer_lr_scale, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - self.z_bias = ( - ParameterMeta.from_dims( - (td_inner,), + if not config.add_bias_linear: + self.z_bias = ParameterMeta.from_dims( + (inner_dim,), weight_decay=False, init_method=init_zeros_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - if not config.add_bias_linear - else 0.0 - ) - self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), - init_method=init_uniform_centered_((td_conv.size * td_conv_kernel.size) ** -0.5), - lr_scale=mamba_layer_lr_scale, + ( + conv1d_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + ), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale + (conv1d_dim,), 
+ init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) - # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_n_v_heads,), + (heads_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - - # out_proj - self.out_proj = Linear( - td_inner, - td_model, + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - ON variable names and pep8: keeping some variable names as in the original code for clarity. - - Args: - u: (B, L, D), - - Returns: - outputs: dict. - outputs["hidden_states"]: (B, L, D). - outputs["state"]: inference cache. 
- """ if kwargs[TransformerKwargs.sequence_first]: raise NotImplementedError(f"Sequence-first not supported for SSMs.") assert _mamba_available - outputs = {} - # assert state is None - batch, seqlen, dim = input_.shape - - state = None - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen + sequence_length = input_.size(0 if kwargs[TransformerKwargs.sequence_first] else 1) # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = torch.nn.functional.pad(input_, (0, 0, 0, padded_len - seqlen)) + padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size + if padded_length != sequence_length: + assert not kwargs[TransformerKwargs.sequence_first] and not self._sequence_parallel + input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) - # Project input - xBCzA_log = self.in_proj(u) + inner_projection = self.in_proj(input_) + # Standardize to (batch, sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) - ( - xBC, - z, - A_log, - ) = torch.split( - xBCzA_log, + xBC, z, A_log = torch.split( + inner_projection, [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, + self._local_inner_size + 2 * self._local_bc_size, + self._local_inner_size, + self._local_heads, ], dim=-1, ) - if state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead torch.nn.functional.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
- xBC_t = einops.rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_( - torch.nn.functional.pad(xBC_t, (self.conv_kernel_size - xBC_t.shape[-1], 0)) - ) # Update state (B D W) - # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) + xBC = self.convolutional_forward(xBC, sequence_length) x, B, C = torch.split( xBC, [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, + self._local_inner_size, + self._local_bc_size, + self._local_bc_size, ], dim=-1, ) - x = einops.rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = einops.rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = einops.rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + x = einops.rearrange(x, "b l (h n) -> b l h n", h=self._local_heads) + B = einops.rearrange(B, "b l (h n) -> b l h n", h=self._local_head_groups) + C = einops.rearrange(C, "b l (h n) -> b l h n", h=self._local_head_groups) # SSM forward - result = _mamba_chunk_scan_combined( + y = _mamba_chunk_scan_combined( x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), dt=A_log, dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), + A=-torch.ones(self._local_heads, device=A_log.device), B=B, C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), + chunk_size=self._config.chunk_size, + return_final_states=False, ) - - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) - else: - y = result - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = einops.rearrange(y + Du, "b l h p -> b l (h p)") # Norm and gate - out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() + if not self._config.add_bias_linear: + z = z + self.z_bias - # TODO: since we do not support inference for now, we only 
return the hidden states for now. - return outputs["hidden_states"], None + # y: (batch, sequence, heads, state) -> (batch, sequence, heads * state) + y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + # TODO: Is contiguous needed? + y = y.transpose(0, 1).contiguous() + return self.out_proj(y) def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" - if _causal_conv1d_available and self.activation_name in ( - "silu", + if _causal_conv1d_available and self._config.activation_type in ( + ActivationType.silu, "swish", - "identity", + ActivationType.identity, ): xBC = _causal_conv1d_fn( xBC.transpose(1, 2), einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), self.conv1d_bias, - activation=None if self.activation_name == "identity" else self.activation_name, + activation=( + None + if self._config.activation_type == ActivationType.identity + else self._config.activation_type.value + ), ).transpose(1, 2) else: - xBC = self.act( + xBC = self._config.activation_type.activation_fn( torch.nn.functional.conv1d( xBC.transpose(1, 2), self.conv1d_weight, bias=self.conv1d_bias, groups=self.conv1d_weight.shape[0], - padding=self.conv_kernel_size - 1, + padding=self._config.conv_kernel_dimension - 1, )[..., :padded_len].transpose(1, 2) ) return xBC diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 8be9dcb9b..cba28f8b8 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -142,12 +142,12 @@ def __init__( # TODO: lr_scale? 
) - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available assert _causal_conv1d_available - inner_projection = self.in_proj(hidden_states) - dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias + inner_projection = self.in_proj(input_) + dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) From c14b7643ae3f840f8da23404922f9482ff507284 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 25 Jul 2025 17:14:17 -0400 Subject: [PATCH 19/40] Mamba 2, misc --- fast_llm/engine/multi_stage/stage_base.py | 5 +- fast_llm/layers/ssm/config.py | 62 ++++++++++--------- fast_llm/layers/ssm/discrete_mamba2.py | 50 ++++++++++----- fast_llm/layers/ssm/mamba2.py | 22 ++++--- fast_llm/layers/ssm/mamba_layer.py | 27 ++++----- fast_llm/tensor.py | 74 +++++++++++++++-------- tests/models/test_checkpoint.py | 11 +++- tests/utils/model_configs.py | 9 +-- 8 files changed, 160 insertions(+), 100 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 9a8ce2092..3218a1963 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -185,8 +185,9 @@ def initialize_weights(self) -> None: # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) global_shape = meta.global_shape - if self._distributed_config.reproducible_init and ( - global_shape.numel() != parameter.numel() or not self._mode.on_device + if meta.requires_global_initialization or ( + self._distributed_config.reproducible_init + and (global_shape.numel() != parameter.numel() or not self._mode.on_device) ): # Initialize all global weights on every gpu, then select the appropriate 
slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 7f0b3cf61..c06d85148 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -5,31 +5,31 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig +from fast_llm.tensor import Initializer from fast_llm.utils import Assert, div class SSMDimNames: # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. state = "ssm_state" # State dimension (N), aka head size / num channels - + head_dim = "ssm_head_dim" head_groups = "ssm_head_groups" group_heads = "ssm_group_heads" + convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers + + dt_rank = "ssm_dt_rank" + + # Composite dimensions composite_heads = "ssm_composite_heads" - composite_heads_and_state = "ssm_composite_heads_and_state" + composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" composite_head_groups_and_state = "ssm_composite_head_groups_and_state" - # Inner projection total dimension. + # Concatenated dimensions + concatenated_convolution = "ssm_concatenated_convolution" + concatenated_x_projection = "ssm_x_concatenated_x_projection" concatenated_inner_projection = "ssm_concatenated_inner_projection" - # Convolution shape in discrete mamba 2. 
TODO: Remove (dim too complex) - conv_dim = "ssm_conv_dim" - - dt_rank = "ssm_dt_rank" - - x_proj_dim = "x_proj_dim" # X projection dimension - conv_kernel = "conv_kernel" # Kernel size of the conv1d in mamba layers - class SSMBlockType(enum.StrEnum): """ @@ -62,7 +62,7 @@ class DTInitType(enum.StrEnum): constant = "constant" random = "random" - def get_init_method(self, scale: float): + def get_init_method(self, scale: float) -> Initializer: from fast_llm.tensor import init_fill_, init_uniform_centered_ return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) @@ -222,56 +222,64 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType # TODO: Use different variables? num_heads = self.n_v_heads num_head_groups = self.n_qk_heads - # v_heads have size `headdim` that may be different from `state_size`. - Assert.multiple(self.d_inner, num_heads) else: raise NotImplementedError(block_type) - tensor_space.add_tensor_dim(state_dim := TensorDim(SSMDimNames.state, self.state_size)) + tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size)) + if block_type == SSMBlockType.mamba2_discrete: + tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads))) + else: + head_dim = state + tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) tensor_space.add_tensor_dim( heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - heads_and_state := CompositeTensorDim( - SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim) + heads_and_head_dim := CompositeTensorDim( + SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim) ) ) tensor_space.add_tensor_dim( head_groups_and_state := 
CompositeTensorDim( - SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim) + SSMDimNames.composite_head_groups_and_state, (head_groups, state) ) ) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel, self.conv_kernel_dimension)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension)) # DT projection if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.dt_rank)) + tensor_space.add_tensor_dim(dt_rank := TensorDim(SSMDimNames.dt_rank, self.dt_rank)) if block_type == SSMBlockType.mamba: - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.dt_rank + self.state_size * 2)) + tensor_space.add_tensor_dim( + ConcatenatedTensorDim(SSMDimNames.concatenated_x_projection, (dt_rank, state, state)) + ) # TODO: Use composition instead tensor_space.add_tensor_dim( - ConcatenatedTensorDim(SSMDimNames.concatenated_inner_projection, (heads_and_state, heads_and_state)) + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, (heads_and_head_dim, heads_and_head_dim) + ) ) elif block_type == SSMBlockType.mamba2: # TODO: Factor out state? tensor_space.add_tensor_dim( ConcatenatedTensorDim( SSMDimNames.concatenated_inner_projection, - (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state), + (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim), ) ) elif block_type == SSMBlockType.mamba2_discrete: - # TODO: Factor as (head_groups, (group_heads + 2) * state_size + group_heads)? 
tensor_space.add_tensor_dim( ConcatenatedTensorDim( SSMDimNames.concatenated_inner_projection, - (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state, heads), + (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim, heads), ) ) - # TODO: (head_groups, group_heads + 2, state_size) tensor_space.add_tensor_dim( - TensorDim(SSMDimNames.conv_dim, self.d_inner + 2 * self.n_qk_heads * self.state_size) + ConcatenatedTensorDim( + SSMDimNames.concatenated_convolution, + (heads_and_head_dim, head_groups_and_state, head_groups_and_state), + ) ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index ac4fb87cc..64377b93c 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -49,14 +49,18 @@ def __init__( layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) + conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.concatenated_convolution) heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) - self._local_heads = heads_dim.size + # local_head_groups = head_groups / TP self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + # local_heads = local_head_groups * group_heads + self._local_heads = heads_dim.size + # local_inner_size = local_heads * head_size self._local_inner_size = inner_dim.size + # local_bc_size = local_head_groups * state self._local_bc_size = tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size # TODO: double check initializations @@ -80,7 +84,7 @@ 
def __init__( ( conv1d_dim, tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -107,24 +111,25 @@ def __init__( ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - if kwargs[TransformerKwargs.sequence_first]: - raise NotImplementedError(f"Sequence-first not supported for SSMs.") - assert _mamba_available - sequence_length = input_.size(0 if kwargs[TransformerKwargs.sequence_first] else 1) + sequence_length = kwargs[TransformerKwargs.sequence_q_dim].global_size # Pad input to nearest multiple of chunklen padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size if padded_length != sequence_length: - assert not kwargs[TransformerKwargs.sequence_first] and not self._sequence_parallel + assert not kwargs[TransformerKwargs.sequence_first] and input_.size(1) == sequence_length input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) + # inner_projection : (batch/local_or_padded_sequence, local_sequence/batch, hidden) + # -> (batch/local_or_padded_sequence, local_sequence/batch, inner_projection) + # inner_projection: (batch, local_or_padded_sequence, hidden) -> (batch, padded_sequence, local_inner_size) inner_projection = self.in_proj(input_) - # Standardize to (batch, sequence, inner_projection) + # Standardize to (batch, padded_sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) + print("QAIKOFNMJOWENM inner_projection", inner_projection.shape) xBC, z, A_log = torch.split( inner_projection, [ @@ -134,9 +139,13 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ], dim=-1, ) + 
print("QAIKOFNMJOWENM xBC", xBC.shape, self._local_inner_size, self._local_bc_size) + print("QAIKOFNMJOWENM z", z.shape) + print("QAIKOFNMJOWENM A_log", A_log.shape) # Convolutional layer - xBC = self.convolutional_forward(xBC, sequence_length) + # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) + xBC = self.convolutional_forward(xBC, padded_length) x, B, C = torch.split( xBC, @@ -148,13 +157,16 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ dim=-1, ) + # x: (batch, padded_sequence, local_heads * head_size) -> (batch, padded_sequence, local_heads, head_size) x = einops.rearrange(x, "b l (h n) -> b l h n", h=self._local_heads) + + # b,c: (batch, padded_sequence, local_head_groups * state) -> (batch, padded_sequence, local_head_groups, state) B = einops.rearrange(B, "b l (h n) -> b l h n", h=self._local_head_groups) C = einops.rearrange(C, "b l (h n) -> b l h n", h=self._local_head_groups) # SSM forward y = _mamba_chunk_scan_combined( - x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), + x=self._apply_a_log(x, A_log), dt=A_log, dt_softplus=True, A=-torch.ones(self._local_heads, device=A_log.device), @@ -169,23 +181,31 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if not self._config.add_bias_linear: z = z + self.z_bias - # y: (batch, sequence, heads, state) -> (batch, sequence, heads * state) + # y: (batch, padded_sequence, local_heads, head_size) -> (batch, sequence, local_heads * head_size) y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: # TODO: Is contiguous needed? 
y = y.transpose(0, 1).contiguous() + # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) + # -> (batch/local_sequence, local_sequence/batch, hidden) + a, b = self.out_proj(y) + logger.info(f"EKFBN y {y.shape}") + logger.info(f"EKFBN a {a.shape}") return self.out_proj(y) + @torch.compile + def _apply_a_log(self, x: torch.Tensor, A_log: torch.Tensor) -> torch.Tensor: + return x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1) + def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" if _causal_conv1d_available and self._config.activation_type in ( ActivationType.silu, - "swish", ActivationType.identity, ): xBC = _causal_conv1d_fn( xBC.transpose(1, 2), - einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + self.conv1d_weight.squeeze(1), self.conv1d_bias, activation=( None diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index cba28f8b8..1ae25e44c 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -39,7 +39,7 @@ class Mamba2(Mixer): _XZ_DIMS = ( TransformerDimNames.batch, - SSMDimNames.composite_heads_and_state, + SSMDimNames.composite_heads_and_head_dim, TransformerDimNames.sequence_q, ) _BC_DIMS = ( @@ -62,7 +62,7 @@ def __init__( layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) + inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_head_dim) xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) @@ -78,7 +78,7 @@ def __init__( ( conv1d_dim, tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -146,6 +146,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ assert _mamba_available assert _causal_conv1d_available + # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) + # -> (batch/sequence, sequence/batch, inner_projection) inner_projection = self.in_proj(input_) dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) @@ -161,10 +163,10 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ dim=2, ) - # z: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + # z: (batch, sequence, local_heads * state) -> (batch, local_heads * state, sequence) z = z.transpose(1, 2) - # x: (batch, sequence, head_groups * state) -> (batch, heads * state, sequence) + # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) x = x.transpose(1, 2) if self._config.repeat_kv_before_conv: x = ( @@ -172,16 +174,16 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ .repeat_interleave(self._group_heads, 1, 
output_size=self._local_heads) .flatten(1, 2) ) - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias.squeeze(1), activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") else: - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias.squeeze(1), activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") x = ( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - # b: (batch, sequence, head_groups * state) -> (batch, heads, state, sequence) + # b: (batch, sequence, local_head_groups * state) -> (batch, local_heads, state, sequence) b = ( b.transpose(1, 2) .unflatten(1, (self._local_head_groups, self._config.state_size)) @@ -216,9 +218,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._debug_level: self._debug_log(y, "y", self._XZ_DIMS, kwargs) - # y: (batch, heads * state, sequence) -> (batch, sequence, heads * state) + # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: # TODO: Is contiguous needed? 
y = y.transpose(0, 1).contiguous() + # (batch/sequence, sequence/batch, local_heads * state) + # -> (batch/local_sequence, local_sequence/batch, hidden) return self.out_proj(y) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 07eec38e6..64c8227fc 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import Assert, get_lr_scale try: @@ -29,30 +29,27 @@ """ -def init_A(d_state, d_inner) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - # TODO: adopt this initialization to work for tensor parallel setting! 
+def init_A(d_state, d_inner) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa if tensor.numel() != d_state * d_inner: - raise ValueError(f"_init_A requires not supported for tensor slices.") - return torch.log( + raise ValueError("_init_A requires not supported for tensor slices.") + torch.log( torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) .unsqueeze(0) .expand(d_inner, d_state), out=tensor, ) - return init_ + return LambdaInitializer(init_, requires_global_initialization=True) -def init_dtprojbias( - dt_max: float, dt_min: float, dt_init_floor: float -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_dtprojbias(dt_max: float, dt_min: float, dt_init_floor: float) -> LambdaInitializer: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min_(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - return tensor.add_(torch.log(-torch.expm1(-tensor))) + tensor.add_(torch.log(-torch.expm1(-tensor))) - return init_ + return LambdaInitializer(init_) class MambaLayer(Mixer): @@ -72,7 +69,7 @@ def __init__( Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) @@ -90,7 +87,7 @@ def __init__( ( inner_dim, tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(SSMDimNames.conv_kernel), + 
tensor_space.get_tensor_dim(SSMDimNames.convolution_kernel), ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -98,7 +95,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim), + tensor_space.get_tensor_dim(SSMDimNames.concatenated_x_projection), weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 899e70005..b89ed4a04 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,3 +1,4 @@ +import abc import functools import math import typing @@ -241,7 +242,7 @@ def __init__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable[["ParameterMeta", torch.Tensor, torch.Generator], torch.Tensor] | None = None, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, weight_decay: bool = True, # Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization. lr_scale: float | None | tuple[float | None, ...] = None, @@ -251,7 +252,11 @@ def __init__( allow_no_grad: bool = False, ): super().__init__(data, tensor_name=tensor_name, dims=dims) - self.param_init_method = init_method + if init_method is not None and not isinstance(init_method, Initializer): + # Support non-wrapped callables for convenience. + assert callable(init_method) + init_method = LambdaInitializer(init_method) + self.param_init_method: Initializer | None = init_method self.param_weight_decay = weight_decay self._is_param = True self.param_grad_is_zero = False @@ -276,7 +281,7 @@ def __new__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None", weight_decay: bool = True, lr_scale: float | None | tuple[float | None, ...] 
= None, allow_sequence_tensor_parallel: bool = True, @@ -303,6 +308,10 @@ def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator self.param_init_method(self, tensor, generator) + @property + def requires_global_initialization(self) -> bool: + return self.param_init_method.requires_global_initialization + def save(self) -> dict[str, typing.Any]: return { "name": self.tensor_name, @@ -334,11 +343,32 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa -def init_fill_(value) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.fill_(value) +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + pass + + requires_global_initialization = False - return init_ + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + return self._init_method(meta, tensor, generator) + + +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) init_zeros_ = init_fill_(0.0) @@ -346,38 +376,32 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_normal_( - mean=0.0, std=1.0, min_val=None, max_val=None 
-) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.normal_(mean, std, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def init_kaiming_(d_in): +def init_kaiming_(d_in: float) -> LambdaInitializer: return init_normal_(0.0, math.sqrt(2.0 / d_in)) def init_uniform_( - low=0.0, high=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.uniform_(low, high, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def init_uniform_centered_( - high, max_val=None, mean=0.0 -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: return init_uniform_( mean - high, mean + high, diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 05acf23dc..4bda5512c 100644 --- a/tests/models/test_checkpoint.py +++ 
b/tests/models/test_checkpoint.py @@ -284,10 +284,15 @@ def test_load_pretrained( @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_huggingface_model(model_testing_config, get_convert_path): # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. + # TODO: Stress the importance of this test as the main correctness test for most models. # TODO: Review test. Move to test_generate? fast_llm_path = get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) hf_path = get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) - model_ref = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + try: + hf_class = model_testing_config.huggingface_model_for_causal_lm_class + except NotImplementedError: + pytest.skip(f"Hugging Face wrapper not implemented for {model_testing_config.name}.") + model_ref = hf_class.from_pretrained( CheckpointLoadConfig( path=get_convert_path(), format=DistributedCheckpointFormat, @@ -298,8 +303,8 @@ def test_huggingface_model(model_testing_config, get_convert_path): 0, model_ref.config.fast_llm_config.base_model.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda" ) output_ref = model_ref(test_input) - model_from_fast_llm = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained(fast_llm_path) - model_from_hf = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + model_from_fast_llm = hf_class.from_pretrained(fast_llm_path) + model_from_hf = hf_class.from_pretrained( CheckpointLoadConfig( path=hf_path, format=model_testing_config.checkpoint_format, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 038b53c26..722d8d63a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -20,6 +20,7 @@ Starcoder2GPTHuggingfaceCheckpointFormat, ) from fast_llm.models.ssm.config import ( + 
AprielSSMHHybridHuggingfaceCheckpointFormat, AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat, ) @@ -540,19 +541,19 @@ def _update_and_add_testing_config( "model.base_model.ssm.chunk_size=32", ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, # TODO: Implement - ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, compare_factor=2.0, # Micro-sequence split and sequence-first not supported. - skip_tests=("sf", "stp", "sdp", "ms"), + skip_tests=("sdp", "ms"), ) From b605bd29bcdd85379a2c43124f07a4c215f53e71 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 25 Jul 2025 18:24:53 -0400 Subject: [PATCH 20/40] doc --- docs/contributing/contributing.md | 4 ++-- docs/contributing/testing.md | 37 ++++++++++++++++++++++++++----- mkdocs.yaml | 1 + 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/docs/contributing/contributing.md b/docs/contributing/contributing.md index 6185b63fe..938fe925f 100644 --- a/docs/contributing/contributing.md +++ b/docs/contributing/contributing.md @@ -40,7 +40,7 @@ Before diving into code, [open an issue](https://github.com/ServiceNow/Fast-LLM/ Here are some tips to ensure your pull request gets reviewed and merged promptly: - **Follow our coding standards**: Stick to our [style guide and conventions](https://servicenow.github.io/Fast-LLM/developers/style-guide) to keep the code clean and consistent. 
-- **Write tests**: Verify your changes with unit tests for new features or bug fixes. +- **Write tests**: Verify your changes with unit tests for new features or bug fixes. See our [testing guide](https://servicenow.github.io/Fast-LLM/contributing/testing) for tips and recommendations on testing. - **Test on GPUs and real-world workloads**: Since Fast-LLM is all about training large language models, make sure your changes work smoothly in GPU environments and on typical training setups. - **Run benchmarks and performance tests**: Make sure your changes don't slow things down. If there's any impact on performance, provide benchmark results to back it up. - **Avoid introducing new issues**: Check that there are no new runtime warnings, type checker errors, linting problems, or unhandled edge cases. @@ -48,7 +48,7 @@ Here are some tips to ensure your pull request gets reviewed and merged promptly - **Keep sensitive data out**: Make sure your code or commit messages don't expose private or proprietary information. - **Use a clear and descriptive title**: The PR title should summarize the key change or feature introduced. Avoid vague titles like "Fix bug" or "Update code." Start with a keyword like `[feat]`, `[fix]`, `[docs]`, etc. to categorize the change. Reference the issue number if applicable (e.g., `[fix] resolve #123 memory leak in training loop`). This title will become the commit message for the squashed merge. - **Use the [PR template](https://github.com/ServiceNow/Fast-LLM/blob/main/.github/PULL_REQUEST_TEMPLATE.md)**: Complete the checklist to make sure everything is in order before hitting submit. -- **Make sure all tests pass before merging**: Run the tests with `pytest tests/ -v -ra -n 10`, and fix any failure before merging. If possible, please run the test in an environment with at least 4 GPUs. +- **Make sure all tests pass before merging**: Run the tests with `pytest tests/ -v -ra -n 10`, and fix any failure before merging. 
If possible, please run the test in an environment with at least 4 GPUs. See our [testing guide](https://servicenow.github.io/Fast-LLM/contributing/testing) for more details on testing and debugging. ## 🆘 Seeking Help or Clarification diff --git a/docs/contributing/testing.md b/docs/contributing/testing.md index 8df93f9d0..9cce78e3c 100644 --- a/docs/contributing/testing.md +++ b/docs/contributing/testing.md @@ -1,13 +1,43 @@ --- -title: Writing tests +title: Writing and running tests --- +## Debugging with tests + +### Selecting tests + +When debugging, it is often practical to target specific tests that will run quickly. While Pytest supports targeting specific directories, files or tests, the complex parameterization and dependencies of our tests often make explicit targeting tedious and/or impractical. We provide several options for selecting tests: + +* `--skip-slow`: This will run a subset of "fast" tests that cover the majority of our codebase. This is useful for quickly checking that changes did not break Fast-LLM too badly before running the full test suite. Note that parallel testing (`-n`) is not needed (and may be counter-productive) with this argument. +* `--run-extra-slow`: Some tests are disabled by default because they take too long to run (ex. complex integration tests) and/or are not particularly important. This argument re-enables them. +* `--models MODEL0 MODEL1 ...`: This allows targeting one or more specific models from the model tests (see below), and is particularly useful when debugging a model. For example, `pytest tests/models/test_models/test_checkpoint.py -v -ra --models llama` will test checkpoints specifically for the llama model. (Note that `-n` may not be needed here as model tests for a given model are only partly distributed due to dependency constraints.)
+ +### Monitoring distributed tests + +`--no-distributed-capture` + +### Other options + +* `--show-gpu-memory N`: Our testing suite monitors GPU memory usage and reports the highest users. Use this option to adjust the number of reported tests (10 by default). Note that this option is mainly intended to make sure tests don't use too much memory (which could cause crashes with lots of parallel tests) and may not be an accurate measurement. +* `--show-skipped`: Many tests skipped for obvious reasons (ex. marked as slow or extra slow, skipped model testing groups (see below)) are removed entirely from the report to reduce clutter. This option may be used to show them explicitly. + +## Best practices + ## Testing models [Model integration tests](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/models) are the most important part of our testing suite, ensuring that Fast-LLM works and yields consistent results for a variety of models, training configurations, optimizations, etc. For each tested model, we run a series of tests divided into several groups. Much of these tests consist of running a short Fast-LLM training run, then comparing intermediate tensors (ex. parameter initialization, layer outputs and gradients, parameter gradients) against a baseline. +### What is being tested + +Coming soon. + +!!! warning "Don't forget about unit tests!" + + While adding a model is a quick and efficient way to increase coverage, it is **not a replacement for unit tests**. + The model testing suite performs intensive consistency checks, but does little to make sure those results are correct to begin with. See [functional tests](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/functional) and [test_lm_head](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py) for good examples of unit tests for individual components and an entire layer. 
+ ### Adding a model When adding support for a new model that comes with additional features, the simplest option to increase coverage is to add an example configuration to the [tested modelsl](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/utils/model_configs.py). @@ -41,11 +71,6 @@ _update_and_add_testing_config( ) ``` -!!! warning "Don't forget about unit tests!" - - While adding a model is a quick and efficient way to increase coverage, it is **not a replacement for unit tests**. - The model testing suite performs intensive consistency checks, but does little to make sure those results are correct to begin with. See [functional tests](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/functional) and [test_lm_head](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py) for good examples of unit tests for individual components and an entire layer. - #### Reference for groups Fast-LLM currently supports the following testing groups: diff --git a/mkdocs.yaml b/mkdocs.yaml index 85fd4bff0..00e52a011 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -189,5 +189,6 @@ nav: - Contribution Guide: contributing/contributing.md - Style Guide: contributing/style-guide.md - Development Practices: contributing/dev-practices.md + - Testing: contributing/testing.md - About Us: about-us.md - Join Us: join-us.md From 5eea938403a74bcf8ee7f0c504e3d8bb6fe118f7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 11:52:24 -0400 Subject: [PATCH 21/40] fix --- fast_llm/models/custom/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef406..534d813ff 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import 
LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.config import GPTBaseModelConfig @@ -31,10 +31,10 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, ) for i in range(self._config.transformer.num_layers) ], From 2e6d082e4b2d7fc3f043365664339a5b823713e6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 15:46:06 -0400 Subject: [PATCH 22/40] fixes --- fast_llm/engine/config_utils/tensor_space.py | 42 +++++++++++++- fast_llm/engine/multi_stage/fsdp.py | 32 +++-------- fast_llm/layers/ssm/config.py | 7 ++- fast_llm/models/gpt/megatron.py | 29 +++++----- fast_llm/tensor.py | 58 ++++++++++++++------ 5 files changed, 109 insertions(+), 59 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 0d971a88a..55d87e271 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -66,13 +66,23 @@ def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": - if self.parallel_group is not None: + if self.is_parallel: from fast_llm.core.ops import gather_op return gather_op(tensor, self.parallel_group, dim) else: return tensor + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + if self.is_parallel: + output = 
tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) + output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) + return output.flatten(dim, dim + 1) + else: + return tensor + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": return ( tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] @@ -111,6 +121,15 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global_partial(tensor, dim + i) + + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): @@ -157,6 +176,27 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor else tensor ) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global_partial(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() diff --git 
a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 5b44bf14b..be15cd37a 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight( where it is located in the shard if it exists, or -1 if it's not in the shard. Used to determine the location of each entry in a different distributed configuration. """ - - # Create an empty index for the global parameter. - index = torch.full( - parameter_meta.global_shape, - -1, - dtype=torch.int64, - device=device, - ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard begin, end = self._get_parameter_range_in_shard(parameter_name) - buffer_index = parameter_meta.global_to_local(index, expand=True) - # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible. - # In that case, we work with a separate tensor to be copied back into `buffer_index`. - try: - buffer_index_flat = buffer_index.view(-1) - is_view = True - except RuntimeError: - buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1) - is_view = False - - # Copy the shard indices at their respective positions in the flat buffer index. - buffer_index_flat[ + # Create an empty local index to hold the local shard indices. + buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device) + + # Copy the shard indices at their respective positions in the buffer index. + buffer_index.flatten()[ self._index_buffer_to_param( self._fsdp_dim.rank * self._shard_size, parameter_name ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) - # If needed, copy the flat buffer index back into the index. 
- if not is_view: - buffer_index.copy_(buffer_index_flat.view_as(buffer_index)) - - return index + # Create a global index from the local one. + return parameter_meta.local_to_global_partial(buffer_index, -1) def copy_shard_overlaps( self, diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c06d85148..9b0949d55 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,13 +1,16 @@ import enum +import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig -from fast_llm.tensor import Initializer from fast_llm.utils import Assert, div +if typing.TYPE_CHECKING: + from fast_llm.tensor import Initializer + class SSMDimNames: # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. 
@@ -62,7 +65,7 @@ class DTInitType(enum.StrEnum): constant = "constant" random = "random" - def get_init_method(self, scale: float) -> Initializer: + def get_init_method(self, scale: float) -> "Initializer": from fast_llm.tensor import init_fill_, init_uniform_centered_ return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index e7379e61e..20ed8e828 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -14,8 +14,8 @@ def get_init_megatron( meta: "ParameterMeta", config: TransformerConfig -) -> typing.Callable[["torch.Tensor", "Distributed"], "torch.Tensor"]: - def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): +) -> typing.Callable[["torch.Tensor", "Distributed"], None]: + def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) if "bias" in meta.tensor_name: # Generator unused. @@ -29,11 +29,11 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): elif config.num_experts > 1 and "mlp.layer_" in meta.tensor_name: tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) elif "mlp.layer_2" in meta.tensor_name: - tensor_ = _init_transposed_mlp_weight_megatron(config, meta, tensor, distributed) + tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) else: # Word embedding (override generator), layer norm (generator unused), other mlp weights. 
return meta.param_init_method(meta, tensor, distributed.tp_init_generator) - return tensor.copy_(tensor_.reshape_as(tensor)) + tensor.copy_(tensor_.reshape_as(tensor)) return init_megatron @@ -58,9 +58,9 @@ def _init_attention_megatron( generator = distributed.tp_init_generator state = generator.get_state() # Initialize a mock dense layer to advance the random state - dense_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + dense_tensor_ := tensor.new_empty( config.kv_channels * config.num_attention_heads, config.hidden_size, ), @@ -68,9 +68,9 @@ def _init_attention_megatron( ) # QKV is split differently. (Assuming no tensor-parallel.) heads_per_group = div(config.num_attention_heads, config.head_groups) - qkv_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + qkv_tensor_ := tensor.new_empty( config.head_groups, heads_per_group + 2, config.kv_channels, @@ -110,18 +110,19 @@ def _init_position_embeddings_megatron( # Megatron initializes the position embeddings on cpu twice. assert meta.param_init_method is not None generator = distributed.default_cpu_generator - tensor_ = meta.param_init_method(meta, torch.empty(tensor.shape, dtype=tensor.dtype), generator) - return meta.param_init_method(meta, tensor_, generator) + meta.param_init_method(meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), generator) + meta.param_init_method(meta, tensor_, generator) + return tensor_ def _init_transposed_mlp_weight_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": import torch # Megatron never transposes the mlp layer 2 weight. 
assert meta.param_init_method is not None - tensor_ = meta.param_init_method(meta, torch.empty_like(tensor), distributed.tp_init_generator) + meta.param_init_method(meta, tensor_ := torch.empty_like(tensor), distributed.tp_init_generator) return tensor_.view(meta.size(1), meta.size(0)).t() @@ -132,8 +133,8 @@ def _init_moe_router_megatron( # Megatron initializes the router on cpu. assert meta.param_init_method is not None - tensor_ = meta.param_init_method( - meta, torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator + meta.param_init_method( + meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator ) return tensor_ diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b89ed4a04..0637931ee 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,5 +1,6 @@ import abc import functools +import logging import math import typing @@ -13,6 +14,8 @@ from fast_llm.functional.triton.pointwise import triton_add, triton_copy from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class _SafeTensorSliceMeta(type): def __instancecheck__(self, instance) -> bool: @@ -159,12 +162,11 @@ def from_tensor_space( def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global( - self, - tensor: torch.Tensor, - *, - distributed: Distributed, - ) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + """ + Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. 
+ """ if tensor.ndim == 0: tensor = tensor[None] Assert.eq(tensor.shape, self.shape) @@ -188,14 +190,32 @@ def local_to_global( Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank - def global_to_local( - self, - tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. TODO: Rework. - expand: bool = False, - ) -> torch.Tensor: + def local_to_global_partial(self, tensor: torch.Tensor, fill_value: float | int = -1) -> torch.Tensor: """ - Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. + Construct a tensor of shape `self.global_shape` that contains its local slice at the appropriate location, + i.e. for which `self.global_to_local(self.local_to_global_partial(tensor)) == tensor`. + Other entries are filled with `fill_value`. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) + assert not self._reductions + logger.info(f"AAAA {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape}") + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global_partial(tensor, dim, fill_value) + logger.info( + f"BBBB {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape} {tensor_dim.is_parallel}" + ) + + Assert.eq(tensor.shape, self.global_shape) + return tensor + + def global_to_local(self, tensor: torch.Tensor | SafeTensorSlice) -> torch.Tensor: + """ + Select the local slice of a global tensor. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. """ # Take a trivial slice to convert safetensor slices. 
tensor = tensor[:] @@ -205,9 +225,9 @@ def global_to_local( Assert.eq(tensor.shape, self.global_shape) for dim, tensor_dim in reversed(list(enumerate(self.dims))): - tensor = tensor_dim.global_to_local(tensor, dim, expand) - if not expand: - Assert.eq(tensor.shape, self.shape) + tensor = tensor_dim.global_to_local(tensor, dim) + + Assert.eq(tensor.shape, self.shape) return tensor @classmethod @@ -302,7 +322,11 @@ def __repr__(self, *, tensor_contents=()) -> str: def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None: assert self.param_init_method is not None - if distributed.config.tensor_parallel == 1 or distributed.config.reproducible_init: + if ( + distributed.config.tensor_parallel == 1 + or distributed.config.reproducible_init + or self.param_init_method.requires_global_initialization + ): generator = distributed.pp_init_generator else: generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator From b6c86138bbdbf19099b799475f17e8d3dcca34b6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 16:08:58 -0400 Subject: [PATCH 23/40] misc --- fast_llm/engine/config_utils/tensor_space.py | 5 +---- fast_llm/layers/language_model/embedding.py | 8 ++++---- fast_llm/layers/language_model/head.py | 10 +++++----- .../layers/language_model/preprocessing.py | 4 ++-- fast_llm/layers/ssm/discrete_mamba2.py | 18 ++++++++--------- fast_llm/layers/ssm/mamba2.py | 20 +++++++++---------- fast_llm/layers/ssm/mamba_layer.py | 16 +++++++-------- fast_llm/layers/transformer/attention.py | 16 +++++++-------- .../layers/transformer/mixture_of_experts.py | 6 +++--- fast_llm/layers/transformer/mlp.py | 6 +++--- fast_llm/layers/transformer/preprocessing.py | 2 +- .../transformer/rotary/preprocessing.py | 4 ++-- fast_llm/layers/transformer/rotary/rotary.py | 4 ++-- fast_llm/layers/transformer/transformer.py | 4 ++-- fast_llm/models/gpt/model.py | 2 +- fast_llm/tensor.py | 2 +- 16 files changed, 62 
insertions(+), 65 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 55d87e271..cf2974a99 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -263,8 +263,5 @@ def add_tensor_dim(self, tensor_dim: TensorDim) -> None: ) self._tensor_dims[tensor_dim.name] = tensor_dim - def get_tensor_dim(self, name: str) -> TensorDim: + def __getitem__(self, name: str) -> TensorDim: return self._tensor_dims[name] - - # TODO: Replace uses - __getitem__ = get_tensor_dim diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 7036a1e97..f6f43d199 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -46,10 +46,10 @@ def __init__( self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - vocab_dim = tensor_space.get_tensor_dim( + hidden_dim = tensor_space[TransformerDimNames.hidden] + vocab_dim = tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size @@ -66,7 +66,7 @@ def __init__( ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim), + (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 21bf3bbd0..210cad644 100644 --- a/fast_llm/layers/language_model/head.py +++ 
b/fast_llm/layers/language_model/head.py @@ -61,7 +61,7 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] self._loss_coefficient = ( config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 @@ -108,9 +108,9 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: if self._tie_word_embeddings or self._prediction_distance > 0: return # untie embedding weights - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), init_method=init_normal_( @@ -338,9 +338,9 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._logits_scale_factor, ) if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) + ] dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) dims[sequence_index] = ( diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index d719bef3d..c8d53a789 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -28,7 +28,7 @@ def __init__( assert config.use_absolute_position_embeddings self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = 
self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: @@ -76,7 +76,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 64377b93c..c9d555de9 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -49,25 +49,25 @@ def __init__( layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.concatenated_convolution) - heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + hidden_dim = tensor_space[TransformerDimNames.hidden] + conv1d_dim = tensor_space[SSMDimNames.concatenated_convolution] + heads_dim = tensor_space[SSMDimNames.composite_heads] # local_head_groups = head_groups / TP - self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + self._local_head_groups = tensor_space[SSMDimNames.head_groups].size # local_heads = local_head_groups * group_heads self._local_heads = heads_dim.size # local_inner_size = local_heads * head_size self._local_inner_size = inner_dim.size # local_bc_size = local_head_groups * state - self._local_bc_size = 
tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size + self._local_bc_size = tensor_space[SSMDimNames.composite_head_groups_and_state].size # TODO: double check initializations # Projections self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), + tensor_space[SSMDimNames.concatenated_inner_projection], bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -83,8 +83,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, - tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 1ae25e44c..77c1b3869 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -62,13 +62,13 @@ def __init__( layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_head_dim) - xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) - hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) - dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) + inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] + hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] + dt_rank_dim = tensor_space[SSMDimNames.dt_rank] - self._local_heads = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads).size - self._local_head_groups = tensor_space.get_tensor_dim(name=SSMDimNames.head_groups).size + self._local_heads = tensor_space[SSMDimNames.composite_heads].size + self._local_head_groups = tensor_space[SSMDimNames.head_groups].size self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size @@ -77,8 +77,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, - tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -90,7 +90,7 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), + tensor_space[SSMDimNames.concatenated_inner_projection], bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -122,7 +122,7 @@ def __init__( lr_scale=lr_scale, ) self.A_log = 
ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(name=SSMDimNames.state)), + (inner_dim, tensor_space[SSMDimNames.state]), init_method=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, weight_decay=False, diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 64c8227fc..9343ef1b8 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -69,8 +69,8 @@ def __init__( Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) + inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + hidden_dim = tensor_space[TransformerDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) @@ -78,7 +78,7 @@ def __init__( # TODO: lr_scale? 
self.in_proj = Linear( hidden_dim, - tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection), + tensor_space[SSMDimNames.concatenated_inner_projection], bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) @@ -86,8 +86,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, - tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(SSMDimNames.convolution_kernel), + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -95,7 +95,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space.get_tensor_dim(SSMDimNames.concatenated_x_projection), + tensor_space[SSMDimNames.concatenated_x_projection], weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, @@ -104,7 +104,7 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.dt_rank)), + (inner_dim, tensor_space[SSMDimNames.dt_rank]), init_method=init_kaiming_(self._config.dt_rank), lr_scale=lr_scale, ) @@ -116,7 +116,7 @@ def __init__( ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.state)), + (inner_dim, tensor_space[SSMDimNames.state]), weight_decay=False, init_method=init_A(self._config.state_size, inner_dim.size), lr_scale=lr_scale, diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 174e19588..c59b191af 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -91,14 +91,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = 
self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size + self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size + self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size + self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -106,7 +106,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space[TransformerDimNames.composite_query], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -115,7 +115,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space[TransformerDimNames.composite_key_value], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -129,7 +129,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # Output. self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space[TransformerDimNames.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 73f83ccf5..4fd2844d5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -63,8 +63,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space.get_tensor_dim(TransformerDimNames.hidden), - tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), + tensor_space[TransformerDimNames.hidden], + tensor_space[TransformerDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max @@ -255,7 +255,7 @@ 
def _debug_log( def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space.get_tensor_dim(dim_name),), + kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index efe0c5cc5..101d97ef3 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -30,8 +30,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = tensor_space[TransformerDimNames.hidden] + self._intermediate_dim = tensor_space[TransformerDimNames.composite_expert_mlp] self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -46,7 +46,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space[TransformerDimNames.composite_gated_expert_mlp], bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index dc3ddeb52..3f0e14eb7 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -28,7 +28,7 @@ def __init__( self._tensor_space = tensor_space 
self._distributed_config = self._tensor_space.distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py index cc83dae02..c357411b6 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -25,8 +25,8 @@ def __init__( self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4c..17b18a1ca 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -82,8 +82,8 @@ def __init__( super().__init__(config, tensor_space) self._tensor_space = tensor_space if self._tensor_space is not None: - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: 
dict[str, typing.Any]) -> None: assert self._tensor_space is not None diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index d08db9a94..75d06f268 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -48,7 +48,7 @@ def _get_meta( } return TensorMeta.from_dims( tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] for dim_name in dim_names ), tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", @@ -97,7 +97,7 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 4c1eab46f..49a5dcbd3 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -155,7 +155,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 0637931ee..b3795b740 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -150,7 +150,7 @@ def from_tensor_space( reductions: tuple[tuple[str, ReduceOp], ...] 
= (), **kwargs: typing.Any, ) -> typing.Self: - dims = tuple(tensor_space.get_tensor_dim(dim_name) for dim_name in dim_names) + dims = tuple(tensor_space[dim_name] for dim_name in dim_names) if reductions: # kwarg not available for ParameterMeta, so we only provide if necessary. kwargs["reductions"] = tuple( From e536af9d935fe789b98683777e3e320eaf5d7e62 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 16:15:17 -0400 Subject: [PATCH 24/40] Concatenated dim --- fast_llm/engine/config_utils/tensor_space.py | 224 +++++++++++++----- fast_llm/engine/multi_stage/fsdp.py | 32 +-- fast_llm/engine/multi_stage/stage_base.py | 5 +- fast_llm/layers/common/config.py | 6 +- fast_llm/layers/common/linear.py | 8 +- fast_llm/layers/common/normalization.py | 4 +- fast_llm/layers/common/peft.py | 4 +- fast_llm/layers/language_model/embedding.py | 8 +- fast_llm/layers/language_model/head.py | 10 +- .../layers/language_model/preprocessing.py | 4 +- fast_llm/layers/transformer/attention.py | 16 +- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 6 +- fast_llm/layers/transformer/preprocessing.py | 2 +- .../transformer/rotary/preprocessing.py | 4 +- fast_llm/layers/transformer/rotary/rotary.py | 4 +- fast_llm/layers/transformer/transformer.py | 4 +- fast_llm/models/gpt/megatron.py | 29 +-- fast_llm/models/gpt/model.py | 2 +- fast_llm/tensor.py | 169 ++++++++----- 20 files changed, 346 insertions(+), 201 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 99c1bcf70..cf2974a99 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -1,3 +1,4 @@ +import logging import math import typing @@ -5,9 +6,13 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: + import torch + from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed +logger = 
logging.getLogger(__name__) + class TensorDim: def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): @@ -19,11 +24,11 @@ def __init__(self, name: str, global_size: int | None, parallel_dim: Distributed def __repr__(self) -> str: return ( - f"TensorDim(" + f"{type(self).__name__}(" f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," - f" parallel_dim={None if self.parallel_dim is None else self._parallel_dim}" + f" parallel_dim={self._parallel_dim}" f")" ) @@ -38,83 +43,180 @@ def name(self) -> str: def size(self) -> int: return self._size - @property - def expanded_shape(self) -> tuple[int, ...]: - return (self._size,) - - @property - def ndim(self) -> int: - return 1 - @property def global_size(self) -> int: return self._global_size @property - def global_expanded_shape(self) -> tuple[int, ...]: - return (self._size if self._parallel_dim is None else self._size * self._parallel_dim.size,) + def is_parallel(self) -> bool: + return self._parallel_dim is not None and self._parallel_dim.size > 1 @property def parallel_dim(self) -> DistributedDim | None: + # TODO: Make more flexible for derived classes? return self._parallel_dim - @property - def parallel_dim_index(self) -> int | None: - return None if self._parallel_dim is None else 0 - @property def parallel_group(self) -> "ProcessGroup|None": + # TODO: Make more flexible for derived classes? 
return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim is not None + assert self.is_parallel return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + if self.is_parallel: + from fast_llm.core.ops import gather_op + + return gather_op(tensor, self.parallel_group, dim) + else: + return tensor + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + if self.is_parallel: + output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) + output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) + return output.flatten(dim, dim + 1) + else: + return tensor + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + return ( + tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] + if self.parallel_dim is not None and self.parallel_dim.size > 1 + else tensor + ) + class CompositeTensorDim(TensorDim): - def __init__(self, name: str, dims: tuple[TensorDim, ...]): - # TODO: Recursive composition?? - parallel_dims = [(i, dim.parallel_dim) for i, dim in enumerate(dims) if dim.parallel_dim] - Assert.leq(len(parallel_dims), 1) + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = None + for dim, tensor_dim in enumerate(tensor_dims): + if tensor_dim.is_parallel: + # TODO: Allow more than one parallel subdim? 
+ assert parallel_dim is None + parallel_dim = tensor_dim.parallel_dim + self._parallel_dim_index = dim super().__init__( name=name, - global_size=math.prod(dim.global_size for dim in dims), - parallel_dim=parallel_dims[0][1] if parallel_dims else None, - ) - self._dims = dims - self._parallel_dim_index = ( - sum(dim.ndim for dim in self._dims[: parallel_dims[0][0]]) - + self._dims[parallel_dims[0][0]].parallel_dim_index - if parallel_dims - else None + global_size=math.prod(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, ) + self._tensor_dims = tensor_dims - @property - def dims(self) -> tuple[TensorDim, ...]: - return self._dims + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self._parallel_dim_index is not None + dims = list(self._tensor_dims) + dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) - @property - def ndim(self) -> int: - return sum(dim.ndim for dim in self._dims) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global(tensor, dim + i) - @property - def expanded_shape(self) -> tuple[int, ...]: - return sum((dim.expanded_shape for dim in self._dims), ()) + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def global_expanded_shape(self) -> tuple[int, ...]: - return sum((dim.global_expanded_shape for dim in self._dims), ()) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global_partial(tensor, dim + i) + + 
return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): + tensor = tensor_dim.global_to_local(tensor, dim + i) + return tensor if expand else tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def parallel_dim_index(self) -> int | None: - return self._parallel_dim_index + +class ConcatenatedTensorDim(TensorDim): + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = tensor_dims[0].parallel_dim + for dim, tensor_dim in enumerate(tensor_dims[1:]): + # TODO: Allow more flexibility? + Assert.is_(tensor_dim.parallel_dim, parallel_dim) + + super().__init__( + name=name, + global_size=sum(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, + ) + self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim_index is not None - dims = list(self.dims) - dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim) - return CompositeTensorDim(self.name, tuple(dims)) + assert self.is_parallel + return ConcatenatedTensorDim( + self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) + ) + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> 
"torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global_partial(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + if self.is_parallel and expand: + raise NotImplementedError() + import torch + + return ( + torch.concatenate( + [ + tensor_dim.global_to_local(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) class DefaultDimNames: @@ -147,21 +249,19 @@ def distributed(self) -> "Distributed": assert self._is_setup return self._distributed - def add_tensor_dim(self, dim: TensorDim) -> None: - if isinstance(dim, CompositeTensorDim): - for dim_ in dim.dims: - Assert.incl(dim_.name, self._tensor_dims) - Assert.eq(dim_, self._tensor_dims[dim_.name]) - if dim.name in self._tensor_dims: - Assert.eq(dim, self._tensor_dims[dim.name]) + def add_tensor_dim(self, tensor_dim: TensorDim) -> None: + if tensor_dim.name in self._tensor_dims: + Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) else: - if dim.parallel_dim is not None: - assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name + if tensor_dim.parallel_dim is not None: + assert ( + tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims + ), tensor_dim.parallel_dim.name Assert.eq( - dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__, + tensor_dim.parallel_dim.__dict__, + self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, ) - self._tensor_dims[dim.name] = dim + 
self._tensor_dims[tensor_dim.name] = tensor_dim - def get_tensor_dim(self, name: str) -> TensorDim: + def __getitem__(self, name: str) -> TensorDim: return self._tensor_dims[name] diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 5b44bf14b..be15cd37a 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight( where it is located in the shard if it exists, or -1 if it's not in the shard. Used to determine the location of each entry in a different distributed configuration. """ - - # Create an empty index for the global parameter. - index = torch.full( - parameter_meta.global_shape, - -1, - dtype=torch.int64, - device=device, - ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard begin, end = self._get_parameter_range_in_shard(parameter_name) - buffer_index = parameter_meta.global_to_local(index, expand=True) - # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible. - # In that case, we work with a separate tensor to be copied back into `buffer_index`. - try: - buffer_index_flat = buffer_index.view(-1) - is_view = True - except RuntimeError: - buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1) - is_view = False - - # Copy the shard indices at their respective positions in the flat buffer index. - buffer_index_flat[ + # Create an empty local index to hold the local shard indices. + buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device) + + # Copy the shard indices at their respective positions in the buffer index. 
+ buffer_index.flatten()[ self._index_buffer_to_param( self._fsdp_dim.rank * self._shard_size, parameter_name ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) - # If needed, copy the flat buffer index back into the index. - if not is_view: - buffer_index.copy_(buffer_index_flat.view_as(buffer_index)) - - return index + # Create a global index from the local one. + return parameter_meta.local_to_global_partial(buffer_index, -1) def copy_shard_overlaps( self, diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 9a8ce2092..3218a1963 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -185,8 +185,9 @@ def initialize_weights(self) -> None: # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) global_shape = meta.global_shape - if self._distributed_config.reproducible_init and ( - global_shape.numel() != parameter.numel() or not self._mode.on_device + if meta.requires_global_initialization or ( + self._distributed_config.reproducible_init + and (global_shape.numel() != parameter.numel() or not self._mode.on_device) ): # Initialize all global weights on every gpu, then select the appropriate slice if applicable. 
global_param = parameter.new_empty(global_shape, device=self._distributed.device) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9f32ac689..07dadbc22 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_ + from fast_llm.tensor import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, @@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> " } if self.initialization_range: mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_( - mean - self.initialization_range, mean + self.initialization_range - ) + kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) return self.module_class(**kwargs) @property diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index cd19a47a5..7249ef569 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -94,8 +94,8 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None - assert out_dim.parallel_dim is None + assert not in_dim.is_parallel + assert not out_dim.is_parallel super().__init__( in_dim, out_dim, @@ -132,7 +132,7 @@ def __init__( sequence_parallel: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None + assert not in_dim.is_parallel self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( @@ -176,7 +176,7 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] 
= None, ): - assert out_dim.parallel_dim is None + assert not out_dim.is_parallel self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 5f30beaef..bccc1d627 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -158,7 +158,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: @@ -242,7 +242,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 3a1966e51..08f3e535b 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -19,12 +19,12 @@ def lora_linear( ): layer.weight.requires_grad = False in_dim = layer._in_dim + assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: - assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." in_dim = TensorDim(in_dim.name, in_dim.global_size) out_dim = layer._out_dim + assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: - assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." 
out_dim = TensorDim(out_dim.name, out_dim.global_size) if out_channel_begin is not None or out_channel_end is not None: if out_channel_begin is None: diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 7036a1e97..f6f43d199 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -46,10 +46,10 @@ def __init__( self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - vocab_dim = tensor_space.get_tensor_dim( + hidden_dim = tensor_space[TransformerDimNames.hidden] + vocab_dim = tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size @@ -66,7 +66,7 @@ def __init__( ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim), + (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 21bf3bbd0..210cad644 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -61,7 +61,7 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] self._loss_coefficient = ( config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 @@ -108,9 +108,9 @@ def 
_init_output_weights(self, hidden_dim: TensorDim, config) -> None: if self._tie_word_embeddings or self._prediction_distance > 0: return # untie embedding weights - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), init_method=init_normal_( @@ -338,9 +338,9 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._logits_scale_factor, ) if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) + ] dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) dims[sequence_index] = ( diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index d719bef3d..c8d53a789 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -28,7 +28,7 @@ def __init__( assert config.use_absolute_position_embeddings self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: @@ -76,7 +76,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = 
self._tensor_space[DefaultDimNames.scalar] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 174e19588..c59b191af 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -91,14 +91,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size + self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size + self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size + self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -106,7 +106,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space[TransformerDimNames.composite_query], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -115,7 +115,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space[TransformerDimNames.composite_key_value], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -129,7 +129,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # Output. self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space[TransformerDimNames.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 73f83ccf5..4fd2844d5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -63,8 +63,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space.get_tensor_dim(TransformerDimNames.hidden), - tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), + tensor_space[TransformerDimNames.hidden], + tensor_space[TransformerDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max @@ -255,7 +255,7 @@ 
def _debug_log( def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space.get_tensor_dim(dim_name),), + kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index efe0c5cc5..101d97ef3 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -30,8 +30,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = tensor_space[TransformerDimNames.hidden] + self._intermediate_dim = tensor_space[TransformerDimNames.composite_expert_mlp] self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -46,7 +46,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space[TransformerDimNames.composite_gated_expert_mlp], bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index dc3ddeb52..3f0e14eb7 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -28,7 +28,7 @@ def __init__( self._tensor_space = tensor_space 
self._distributed_config = self._tensor_space.distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py index cc83dae02..c357411b6 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -25,8 +25,8 @@ def __init__( self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4c..17b18a1ca 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -82,8 +82,8 @@ def __init__( super().__init__(config, tensor_space) self._tensor_space = tensor_space if self._tensor_space is not None: - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: 
dict[str, typing.Any]) -> None: assert self._tensor_space is not None diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index d08db9a94..75d06f268 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -48,7 +48,7 @@ def _get_meta( } return TensorMeta.from_dims( tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] for dim_name in dim_names ), tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", @@ -97,7 +97,7 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index e7379e61e..20ed8e828 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -14,8 +14,8 @@ def get_init_megatron( meta: "ParameterMeta", config: TransformerConfig -) -> typing.Callable[["torch.Tensor", "Distributed"], "torch.Tensor"]: - def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): +) -> typing.Callable[["torch.Tensor", "Distributed"], None]: + def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) if "bias" in meta.tensor_name: # Generator unused. 
@@ -29,11 +29,11 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): elif config.num_experts > 1 and "mlp.layer_" in meta.tensor_name: tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) elif "mlp.layer_2" in meta.tensor_name: - tensor_ = _init_transposed_mlp_weight_megatron(config, meta, tensor, distributed) + tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) else: # Word embedding (override generator), layer norm (generator unused), other mlp weights. return meta.param_init_method(meta, tensor, distributed.tp_init_generator) - return tensor.copy_(tensor_.reshape_as(tensor)) + tensor.copy_(tensor_.reshape_as(tensor)) return init_megatron @@ -58,9 +58,9 @@ def _init_attention_megatron( generator = distributed.tp_init_generator state = generator.get_state() # Initialize a mock dense layer to advance the random state - dense_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + dense_tensor_ := tensor.new_empty( config.kv_channels * config.num_attention_heads, config.hidden_size, ), @@ -68,9 +68,9 @@ def _init_attention_megatron( ) # QKV is split differently. (Assuming no tensor-parallel.) heads_per_group = div(config.num_attention_heads, config.head_groups) - qkv_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + qkv_tensor_ := tensor.new_empty( config.head_groups, heads_per_group + 2, config.kv_channels, @@ -110,18 +110,19 @@ def _init_position_embeddings_megatron( # Megatron initializes the position embeddings on cpu twice. 
assert meta.param_init_method is not None generator = distributed.default_cpu_generator - tensor_ = meta.param_init_method(meta, torch.empty(tensor.shape, dtype=tensor.dtype), generator) - return meta.param_init_method(meta, tensor_, generator) + meta.param_init_method(meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), generator) + meta.param_init_method(meta, tensor_, generator) + return tensor_ def _init_transposed_mlp_weight_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": import torch # Megatron never transposes the mlp layer 2 weight. assert meta.param_init_method is not None - tensor_ = meta.param_init_method(meta, torch.empty_like(tensor), distributed.tp_init_generator) + meta.param_init_method(meta, tensor_ := torch.empty_like(tensor), distributed.tp_init_generator) return tensor_.view(meta.size(1), meta.size(0)).t() @@ -132,8 +133,8 @@ def _init_moe_router_megatron( # Megatron initializes the router on cpu. 
assert meta.param_init_method is not None - tensor_ = meta.param_init_method( - meta, torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator + meta.param_init_method( + meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator ) return tensor_ diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 4c1eab46f..49a5dcbd3 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -155,7 +155,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d780e4d6d..b3795b740 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,17 +1,21 @@ +import abc import functools +import logging import math import typing import torch from fast_llm.core.distributed import ReduceOp -from fast_llm.core.ops import gather_op, reduce_op +from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.triton.pointwise import triton_add, triton_copy from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class _SafeTensorSliceMeta(type): def __instancecheck__(self, instance) -> bool: @@ -146,7 +150,7 @@ def from_tensor_space( reductions: tuple[tuple[str, ReduceOp], ...] 
= (), **kwargs: typing.Any, ) -> typing.Self: - dims = tuple(tensor_space.get_tensor_dim(dim_name) for dim_name in dim_names) + dims = tuple(tensor_space[dim_name] for dim_name in dim_names) if reductions: # kwarg not available for ParameterMeta, so we only provide if necessary. kwargs["reductions"] = tuple( @@ -158,22 +162,23 @@ def from_tensor_space( def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global( - self, - tensor: torch.Tensor, - *, - distributed: Distributed, - ) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + """ + Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) # Tensors are always either split or duplicated in the tensor-parallel direction. 
# TODO: Avoid hard-coded assumptions on duplication - is_first_rank = distributed.config.tensor_rank == 0 - modified = False - for i, dim in enumerate(self.dims): - if dim.parallel_group is not None: - tensor = gather_op( - tensor.unflatten(i, dim.expanded_shape), dim.parallel_group, i + dim.parallel_dim_index - ).flatten(i, i + len(dim.expanded_shape) - 1) - is_first_rank, modified = is_first_rank and dim.parallel_group.rank() == 0, True + is_first_rank, modified = distributed.config.tensor_rank == 0, False + + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global(tensor, dim) + is_first_rank &= tensor_dim.parallel_dim.rank == 0 + modified = True for distributed_dim, op in self._reductions: if distributed_dim.group is not None: @@ -182,28 +187,48 @@ def local_to_global( tensor = tensor.clone() tensor = reduce_op(tensor, distributed_dim.group, op=op) is_first_rank, modified = is_first_rank and distributed_dim.group.rank() == 0, True + Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank - def global_to_local( - self, - tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. - expand: bool = False, - ) -> torch.Tensor: + def local_to_global_partial(self, tensor: torch.Tensor, fill_value: float | int = -1) -> torch.Tensor: """ - Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. + Construct a tensor of shape `self.global_shape` that contains its local slice at the appropriate location, + i.e. for which `self.global_to_local(self.local_to_global_partial(tensor)) == tensor`. + Other entries are filled with `fill_value`. + Returns a view of the input tensor (or the input tensor itself) when possible. 
+ """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) + assert not self._reductions + logger.info(f"AAAA {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape}") + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global_partial(tensor, dim, fill_value) + logger.info( + f"BBBB {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape} {tensor_dim.is_parallel}" + ) + + Assert.eq(tensor.shape, self.global_shape) + return tensor + + def global_to_local(self, tensor: torch.Tensor | SafeTensorSlice) -> torch.Tensor: + """ + Select the local slice of a global tensor. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. """ # Take a trivial slice to convert safetensor slices. - tensor_ = tensor[:] + tensor = tensor[:] assert not self._reductions + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.global_shape) - for i, dim in reversed(list(enumerate(self.dims))): - if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = tensor_.unflatten(i, dim.global_expanded_shape).chunk( - dim.parallel_dim.size, i + dim.parallel_dim_index - )[dim.parallel_dim.rank] + for dim, tensor_dim in reversed(list(enumerate(self.dims))): + tensor = tensor_dim.global_to_local(tensor, dim) - return tensor_ if expand else tensor_.reshape(self.shape) + Assert.eq(tensor.shape, self.shape) + return tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -237,7 +262,7 @@ def __init__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable[["ParameterMeta", torch.Tensor, torch.Generator], torch.Tensor] | None = None, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, weight_decay: bool = True, # Pass a list to split the parameter in contiguous (dim=0) chunks 
of equal size for optimization. lr_scale: float | None | tuple[float | None, ...] = None, @@ -247,7 +272,11 @@ def __init__( allow_no_grad: bool = False, ): super().__init__(data, tensor_name=tensor_name, dims=dims) - self.param_init_method = init_method + if init_method is not None and not isinstance(init_method, Initializer): + # Support non-wrapped callables for convenience. + assert callable(init_method) + init_method = LambdaInitializer(init_method) + self.param_init_method: Initializer | None = init_method self.param_weight_decay = weight_decay self._is_param = True self.param_grad_is_zero = False @@ -272,7 +301,7 @@ def __new__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None", weight_decay: bool = True, lr_scale: float | None | tuple[float | None, ...] = None, allow_sequence_tensor_parallel: bool = True, @@ -293,12 +322,20 @@ def __repr__(self, *, tensor_contents=()) -> str: def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None: assert self.param_init_method is not None - if distributed.config.tensor_parallel == 1 or distributed.config.reproducible_init: + if ( + distributed.config.tensor_parallel == 1 + or distributed.config.reproducible_init + or self.param_init_method.requires_global_initialization + ): generator = distributed.pp_init_generator else: generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator self.param_init_method(self, tensor, generator) + @property + def requires_global_initialization(self) -> bool: + return self.param_init_method.requires_global_initialization + def save(self) -> dict[str, typing.Any]: return { "name": self.tensor_name, @@ -330,11 +367,32 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa -def init_fill_(value) -> 
typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.fill_(value) +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + pass + + requires_global_initialization = False + + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + return self._init_method(meta, tensor, generator) + - return init_ +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) init_zeros_ = init_fill_(0.0) @@ -342,30 +400,35 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_normal_( - mean=0.0, std=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.normal_(mean, std, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def kaiming_init_(d_in): 
+def init_kaiming_(d_in: float) -> LambdaInitializer: return init_normal_(0.0, math.sqrt(2.0 / d_in)) def init_uniform_( - low=0.0, high=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.uniform_(low, high, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + - return init_ +def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: + return init_uniform_( + mean - high, + mean + high, + min_val=None if max_val is None else mean - max_val, + max_val=None if max_val is None else mean + max_val, + ) From 017f5cc5a021d9a2ef58e5d1903f60c4917f311c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 18:09:53 -0400 Subject: [PATCH 25/40] fixes --- fast_llm/layers/ssm/discrete_mamba2.py | 24 ++++++++++----------- fast_llm/layers/ssm/mamba2.py | 26 +++++++++++----------- fast_llm/layers/ssm/mamba_layer.py | 30 ++++++++++++-------------- 3 files changed, 39 insertions(+), 41 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index c0ae7e781..6012f74a7 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor 
import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_, init_zeros_ from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -62,14 +62,14 @@ def __init__( mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_discrete_mamba2) + td_inner = tensor_space[SSMDimNames.inner_dim] + td_state = tensor_space[SSMDimNames.state_dim] + td_model = tensor_space[SSMDimNames.model_dim] + td_conv = tensor_space[SSMDimNames.conv_dim] + td_n_qk_heads = tensor_space[SSMDimNames.qk_heads] + td_n_v_heads = tensor_space[SSMDimNames.v_heads] + td_conv_kernel = tensor_space[SSMDimNames.conv_kernel_size] + td_inner_proj = tensor_space[SSMDimNames.inner_proj_discrete_mamba2] self.d_model = td_model.size self.d_inner = td_inner.size @@ -88,7 +88,7 @@ def __init__( td_model, td_inner_proj, bias=bias, - weight_init_method=kaiming_init_(td_model.size), + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.z_bias = ( @@ -103,7 +103,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), + (td_conv, tensor_space[DefaultDimNames.scalar], td_conv_kernel), init_method=init_uniform_( 1 / 
math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 @@ -126,7 +126,7 @@ def __init__( td_inner, td_model, bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + weight_init_method=init_kaiming_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 74c212add..9dfad8462 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ +from fast_llm.tensor import ParameterMeta, init_fill_, init_kaiming_, init_ones_, init_uniform_ from fast_llm.utils import get_lr_scale try: @@ -80,13 +80,13 @@ def __init__( self.config.mamba_lr_scale, layer_lr_scale ) - td_inner: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_dim) - td_state: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.state_dim) - td_model: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.model_dim) - tdt_rank: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - td_xb: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.x_proj_dim_2) - td_inner_proj: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_proj_mamba2) - td_conv_kernel: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel_size) + td_inner: TensorDim = tensor_space[SSMDimNames.inner_dim] + td_state: TensorDim = tensor_space[SSMDimNames.state_dim] + td_model: TensorDim = tensor_space[SSMDimNames.model_dim] + tdt_rank: TensorDim = tensor_space[SSMDimNames.dt_rank] + td_xb: TensorDim = 
tensor_space[SSMDimNames.x_proj_dim_2] + td_inner_proj: TensorDim = tensor_space[SSMDimNames.inner_proj_mamba2] + td_conv_kernel: TensorDim = tensor_space[SSMDimNames.conv_kernel_size] self.repeat_kv_before_conv = config.repeat_kv_before_conv @@ -98,7 +98,7 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), + (td_inner, tensor_space[DefaultDimNames.scalar], td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), @@ -111,7 +111,7 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), + (td_xb, tensor_space[DefaultDimNames.scalar], td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), @@ -131,14 +131,14 @@ def __init__( td_model, td_inner_proj, bias=bias, - weight_init_method=kaiming_init_(td_model.size), + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.dt_in_proj = Linear( td_model, tdt_rank, bias=config.add_bias_linear, - weight_init_method=kaiming_init_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(transformer_config.hidden_size), lr_scale=mamba_layer_lr_scale, ) # Initialize special dt projection to preserve variance at initialization @@ -185,7 +185,7 @@ def __init__( td_inner, td_model, bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + weight_init_method=init_kaiming_(td_inner.size), ) def forward(self, hidden_states, kwargs): diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 4493332ce..5e0ae786e 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.config import SSMConfig, 
SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import get_lr_scale try: @@ -75,15 +75,13 @@ def __init__( self.config: SSMConfig = config # Tensor dims: - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_inner_proj = tensor_space.get_tensor_dim( - SSMDimNames.inner_proj_mamba - ) # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) - td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) + td_inner = tensor_space[SSMDimNames.inner_dim] + td_inner_proj = tensor_space[SSMDimNames.inner_proj_mamba] # TensorDim("D_inner_2", self.d_inner * 2) + tdt_rank = tensor_space[SSMDimNames.dt_rank] + td_x_proj = tensor_space[SSMDimNames.x_proj_dim] + td_state = tensor_space[SSMDimNames.state_dim] + td_model = tensor_space[SSMDimNames.model_dim] + td_conv_kernel = tensor_space[SSMDimNames.conv_kernel_size] self.d_conv = td_conv_kernel.size self.d_inner = td_inner.size self.d_state = td_state.size @@ -94,12 +92,12 @@ def __init__( self.in_proj_weight = ParameterMeta.from_dims( (td_inner_proj, td_model), - init_method=kaiming_init_(td_model.size), + init_method=init_kaiming_(td_model.size), ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), - init_method=kaiming_init_(td_inner.size), + (td_inner, tensor_space[DefaultDimNames.scalar], td_conv_kernel), + init_method=init_kaiming_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) @@ -111,7 +109,7 @@ def __init__( self.x_proj = Linear( 
td_inner, td_x_proj, - weight_init_method=kaiming_init_(td_inner.size), + weight_init_method=init_kaiming_(td_inner.size), bias=False, lr_scale=mamba_layer_lr_scale, ) @@ -120,7 +118,7 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( (td_inner, tdt_rank), - init_method=kaiming_init_(tdt_rank.size), + init_method=init_kaiming_(tdt_rank.size), lr_scale=mamba_layer_lr_scale, ) @@ -151,7 +149,7 @@ def __init__( td_inner, td_model, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. - weight_init_method=kaiming_init_(td_model.size), + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.out_proj.weight.auto_grad_accumulation = True From c41efc21ae2f8c1a87d35834a28ae3ad852f22d4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 19:24:46 -0400 Subject: [PATCH 26/40] doc --- docs/contributing/testing.md | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/docs/contributing/testing.md b/docs/contributing/testing.md index 9cce78e3c..e04cf1b37 100644 --- a/docs/contributing/testing.md +++ b/docs/contributing/testing.md @@ -6,20 +6,22 @@ title: Writing and running tests ### Selecting tests -When debugging, it is often practical to target specific tests that will run quickly. While Pytest supports targeting specific directory, files or tests, the complex parameterization and dependencies of our tests often makes explicit targeting tedious and/or impractical. We provide several options for selecting tests: +When debugging, it is often advisable to target specific tests that can be executed efficiently. 
Although Pytest allows targeting specific tests or files, complex parameterization and dependencies in our suite often make explicit selection difficult. To address this, several options for test selection are available: -* `--skip-slow`: This will run a subset of "fast" tests that cover the majority of our codebase. This is useful for quickly checking that changes did not break Fast-LLM too badly before running the full test suite. Note that parallel testing (`-n`) is not needed (and may be counter-productive) with this argument. -* `--run-extra-slow`: Some tests are disabled by default because they take too long to run (ex. complex integration tests) and/or are not particularly important. This argument re-enables them. -* `--models MODEL0 MODEL1 ...`: This allows targeting one or more specific models from the model tests (see below), and is particularly useful when debugging a model. For example, `pytest tests/models/test_models/test_checkpoint.py -v -ra --models llama` will test checkpoints specifically for the llama model. (Note that `-n` may not be needed here as model tests for a given model are only partly distributed dure to dependency constraints.) +* `--skip-slow`: Executes a subset of expedited tests that encompass much of the codebase. This option is effective for quickly checking for major regressions prior to executing the comprehensive test suite. Please note, parallel testing (`-n`) is typically unnecessary—and may even be counterproductive—when using this argument. +* `--run-extra-slow`: Certain tests are disabled by default due to their lengthy execution times (e.g., complex integration tests) or limited criticality. Use this flag to re-enable them. +* `--models MODEL0 MODEL1 ...`: Enables targeting of one or more specific models within the model testing suite. This feature is particularly useful during model-specific debugging efforts. 
For instance, running `pytest tests/models/test_models/test_checkpoint.py -v -ra --models llama` will specifically test checkpointing functionality for the llama model. Note that parallelization (`-n`) may be unnecessary in this context, as model tests for a given model are only partially distributed due to dependency constraints. ### Monitoring distributed tests -`--no-distributed-capture` +Distributed tests are generally the slowest due to the overhead associated with starting processes and process groups. To mitigate this, Fast-LLM incorporates several bundled tests that execute multiple subtests within a single subprocess call. As bundled calls can generate substantial output and potentially reduce report readability, Fast-LLM captures the output from each subtest and forwards it to an associated test. If necessary, this output capture can be disabled using `--no-distributed-capture`—for instance, if a severe crash hinders output capture or to disable pytest capture entirely (`-s`). Captured logs are stored in the testing cache directory; please consult individual tests for specific locations. + +For example, `test_run_model_distributed[llama]` tries various distributed configurations for the `llama` model, each reported under an associated test such as `test_model_distributed[llama-distributed]`. Should a distributed subtest, say `tp2` (tensor-parallel), encounter a failure, `test_run_model_distributed` will log the issue, continue executing remaining subtests, and ultimately raise an error to designate the bundled test as failed. The associated test, `test_model_distributed[llama-tp2]`, will also fail and display the captured output (retrieved from `/tmp/fast_llm_tests/models/llama/tp2/`), separated by type (stdout, stderr and traceback) as would happen for a normal test (minus some advanced formating), but also by rank. ### Other options -* `--show-gpu-memory N`: Our testing suite monitors GPU memory usage and reports the highest users. 
Use this option to adjust the number of reported tests (10 by default). Note that this option is mainly intended to make sure tests don't use too much memory (which could cause crashes with lots of parallel tests) and may not be an accurate measurement. -* `--show-skipped`: Many tests skipped for obvious reasons (ex. marked as slow or extra slow, skipped model testing groups (see below)) are removed entirely from the report to reduce clutter. This option may be used to show them explicitly. +* `--show-gpu-memory N`: Monitors GPU memory use and reports the top N tests (default 10). Mainly helps ensure tests don't exceed memory limits, but results may not be precise. +* `--show-skipped`: Many tests skipped for obvious reasons (ex. marked as slow or extra slow, skipped model testing groups (see below)) are removed entirely from the report to reduce clutter. Use this flag to display them. ## Best practices @@ -29,15 +31,6 @@ When debugging, it is often practical to target specific tests that will run qui For each tested model, we run a series of tests divided into several groups. Much of these tests consist of running a short Fast-LLM training run, then comparing intermediate tensors (ex. parameter initialization, layer outputs and gradients, parameter gradients) against a baseline. -### What is being tested - -Coming soon. - -!!! warning "Don't forget about unit tests!" - - While adding a model is a quick and efficient way to increase coverage, it is **not a replacement for unit tests**. - The model testing suite performs intensive consistency checks, but does little to make sure those results are correct to begin with. See [functional tests](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/functional) and [test_lm_head](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py) for good examples of unit tests for individual components and an entire layer. 
- ### Adding a model When adding support for a new model that comes with additional features, the simplest option to increase coverage is to add an example configuration to the [tested modelsl](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/utils/model_configs.py). @@ -71,6 +64,11 @@ _update_and_add_testing_config( ) ``` +!!! warning "Don't forget about unit tests!" + + While adding a model is a quick and efficient way to increase coverage, it is **not a replacement for unit tests**. + The model testing suite performs intensive consistency checks, but does little to make sure those results are correct to begin with. See [functional tests](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/functional) and [test_lm_head](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py) for good examples of unit tests for individual components and an entire layer. + #### Reference for groups Fast-LLM currently supports the following testing groups: From 0b8bd5dc7a09d73adc2fe08a1aa2924052bd01b5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 19:38:26 -0400 Subject: [PATCH 27/40] cleanup --- docs/contributing/contributing.md | 4 ++-- docs/contributing/testing.md | 25 +------------------------ mkdocs.yaml | 1 - setup.cfg | 2 +- 4 files changed, 4 insertions(+), 28 deletions(-) diff --git a/docs/contributing/contributing.md b/docs/contributing/contributing.md index 938fe925f..6185b63fe 100644 --- a/docs/contributing/contributing.md +++ b/docs/contributing/contributing.md @@ -40,7 +40,7 @@ Before diving into code, [open an issue](https://github.com/ServiceNow/Fast-LLM/ Here are some tips to ensure your pull request gets reviewed and merged promptly: - **Follow our coding standards**: Stick to our [style guide and conventions](https://servicenow.github.io/Fast-LLM/developers/style-guide) to keep the code clean and consistent. -- **Write tests**: Verify your changes with unit tests for new features or bug fixes. 
See our [testing guide](https://servicenow.github.io/Fast-LLM/contributing/testing) for tips and recommendations on testing. +- **Write tests**: Verify your changes with unit tests for new features or bug fixes. - **Test on GPUs and real-world workloads**: Since Fast-LLM is all about training large language models, make sure your changes work smoothly in GPU environments and on typical training setups. - **Run benchmarks and performance tests**: Make sure your changes don't slow things down. If there's any impact on performance, provide benchmark results to back it up. - **Avoid introducing new issues**: Check that there are no new runtime warnings, type checker errors, linting problems, or unhandled edge cases. @@ -48,7 +48,7 @@ Here are some tips to ensure your pull request gets reviewed and merged promptly - **Keep sensitive data out**: Make sure your code or commit messages don't expose private or proprietary information. - **Use a clear and descriptive title**: The PR title should summarize the key change or feature introduced. Avoid vague titles like "Fix bug" or "Update code." Start with a keyword like `[feat]`, `[fix]`, `[docs]`, etc. to categorize the change. Reference the issue number if applicable (e.g., `[fix] resolve #123 memory leak in training loop`). This title will become the commit message for the squashed merge. - **Use the [PR template](https://github.com/ServiceNow/Fast-LLM/blob/main/.github/PULL_REQUEST_TEMPLATE.md)**: Complete the checklist to make sure everything is in order before hitting submit. -- **Make sure all tests pass before merging**: Run the tests with `pytest tests/ -v -ra -n 10`, and fix any failure before merging. If possible, please run the test in an environment with at least 4 GPUs. See our [testing guide](https://servicenow.github.io/Fast-LLM/contributing/testing) for more details on testing and debugging. 
+- **Make sure all tests pass before merging**: Run the tests with `pytest tests/ -v -ra -n 10`, and fix any failure before merging. If possible, please run the test in an environment with at least 4 GPUs. ## 🆘 Seeking Help or Clarification diff --git a/docs/contributing/testing.md b/docs/contributing/testing.md index e04cf1b37..8df93f9d0 100644 --- a/docs/contributing/testing.md +++ b/docs/contributing/testing.md @@ -1,30 +1,7 @@ --- -title: Writing and running tests +title: Writing tests --- -## Debugging with tests - -### Selecting tests - -When debugging, it is often advisable to target specific tests that can be executed efficiently. Although Pytest allows targeting specific tests or files, complex parameterization and dependencies in our suite often make explicit selection difficult. To address this, several options for test selection are available: - -* `--skip-slow`: Executes a subset of expedited tests that encompass much of the codebase. This option is effective for quickly checking for major regressions prior to executing the comprehensive test suite. Please note, parallel testing (`-n`) is typically unnecessary—and may even be counterproductive—when using this argument. -* `--run-extra-slow`: Certain tests are disabled by default due to their lengthy execution times (e.g., complex integration tests) or limited criticality. Use this flag to re-enable them. -* `--models MODEL0 MODEL1 ...`: Enables targeting of one or more specific models within the model testing suite. This feature is particularly useful during model-specific debugging efforts. For instance, running `pytest tests/models/test_models/test_checkpoint.py -v -ra --models llama` will specifically test checkpointing functionality for the llama model. Note that parallelization (`-n`) may be unnecessary in this context, as model tests for a given model are only partially distributed due to dependency constraints. 
- -### Monitoring distributed tests - -Distributed tests are generally the slowest due to the overhead associated with starting processes and process groups. To mitigate this, Fast-LLM incorporates several bundled tests that execute multiple subtests within a single subprocess call. As bundled calls can generate substantial output and potentially reduce report readability, Fast-LLM captures the output from each subtest and forwards it to an associated test. If necessary, this output capture can be disabled using `--no-distributed-capture`—for instance, if a severe crash hinders output capture or to disable pytest capture entirely (`-s`). Captured logs are stored in the testing cache directory; please consult individual tests for specific locations. - -For example, `test_run_model_distributed[llama]` tries various distributed configurations for the `llama` model, each reported under an associated test such as `test_model_distributed[llama-distributed]`. Should a distributed subtest, say `tp2` (tensor-parallel), encounter a failure, `test_run_model_distributed` will log the issue, continue executing remaining subtests, and ultimately raise an error to designate the bundled test as failed. The associated test, `test_model_distributed[llama-tp2]`, will also fail and display the captured output (retrieved from `/tmp/fast_llm_tests/models/llama/tp2/`), separated by type (stdout, stderr and traceback) as would happen for a normal test (minus some advanced formating), but also by rank. - -### Other options - -* `--show-gpu-memory N`: Monitors GPU memory use and reports the top N tests (default 10). Mainly helps ensure tests don't exceed memory limits, but results may not be precise. -* `--show-skipped`: Many tests skipped for obvious reasons (ex. marked as slow or extra slow, skipped model testing groups (see below)) are removed entirely from the report to reduce clutter. Use this flag to display them. 
- -## Best practices - ## Testing models [Model integration tests](https://github.com/ServiceNow/Fast-LLM/blob/main/tests/models) are the most important part of our testing suite, ensuring that Fast-LLM works and yields consistent results for a variety of models, training configurations, optimizations, etc. diff --git a/mkdocs.yaml b/mkdocs.yaml index 00e52a011..85fd4bff0 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -189,6 +189,5 @@ nav: - Contribution Guide: contributing/contributing.md - Style Guide: contributing/style-guide.md - Development Practices: contributing/dev-practices.md - - Testing: contributing/testing.md - About Us: about-us.md - Join Us: join-us.md diff --git a/setup.cfg b/setup.cfg index dc6d0c445..843aa15ca 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,7 +48,7 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. for IDE support): -# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation +# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 cartesia_pytorch>=0.0.2 From 6bf06d6aecb9a2a0de67ad7a42690db071a812f4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Jul 2025 15:51:13 -0400 Subject: [PATCH 28/40] fix --- fast_llm/tensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b3795b740..d080e6a1e 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -201,13 +201,9 @@ def local_to_global_partial(self, tensor: torch.Tensor, fill_value: float | int tensor = tensor[None] Assert.eq(tensor.shape, self.shape) assert not self._reductions - logger.info(f"AAAA {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape}") for dim, tensor_dim in enumerate(self.dims): if tensor_dim.is_parallel: tensor = 
tensor_dim.local_to_global_partial(tensor, dim, fill_value) - logger.info( - f"BBBB {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape} {tensor_dim.is_parallel}" - ) Assert.eq(tensor.shape, self.global_shape) return tensor From 2ddc3a748817ee98785344e03809cfd67590e954 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Jul 2025 16:15:10 -0400 Subject: [PATCH 29/40] fix --- fast_llm/engine/config_utils/tensor_space.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index cf2974a99..6c4b95b20 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -95,7 +95,7 @@ class CompositeTensorDim(TensorDim): def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): parallel_dim = None for dim, tensor_dim in enumerate(tensor_dims): - if tensor_dim.is_parallel: + if tensor_dim.parallel_dim is not None: # TODO: Allow more than one parallel subdim? 
assert parallel_dim is None parallel_dim = tensor_dim.parallel_dim From cef7c155ebe08c40a20b61a1c9f930ee223007f7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 30 Jul 2025 12:20:46 -0400 Subject: [PATCH 30/40] fix --- fast_llm/models/ssm/config.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 9427f69be..866de962f 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -4,13 +4,18 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.checkpoint.config import CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import ( + GPTBaseModelConfig, + GPTBatchConfig, + GPTHuggingfaceCheckpointFormat, + PretrainedGPTModelConfig, +) from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -79,8 +84,7 @@ def _validate(self): self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None -class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class LLambaHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "llamba" @classmethod @@ -90,8 +94,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return LLambaHuggingfaceCheckpointHandler -class AprielSSMHuggingfaceCheckpointFormat(CheckpointFormat): - 
support_optimizer: typing.ClassVar[bool] = False +class AprielSSMHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm" @classmethod @@ -101,8 +104,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHuggingfaceCheckpointHandler -class AprielSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm_hybrid" @classmethod @@ -112,8 +114,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHHybridHuggingfaceCheckpointHandler -class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" @classmethod From 8abf2587028c36fa2fab29c12db57a189bfe3c0f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 Aug 2025 16:21:17 -0400 Subject: [PATCH 31/40] fixes --- fast_llm/models/gpt/config.py | 4 -- fast_llm/models/ssm/conversion.py | 61 ++++++++++++++++---------- tests/conftest.py | 20 ++------- tests/data/common.py | 2 +- tests/data/test_blending.py | 3 +- tests/data/test_concatenate.py | 3 +- tests/data/test_concatenated_memmap.py | 3 +- tests/data/test_dataset_from_file.py | 3 +- tests/data/test_fim.py | 3 +- tests/data/test_memmap.py | 3 +- tests/data/test_sampling.py | 3 +- tests/data/test_slice.py | 3 +- tests/models/test_checkpoint.py | 4 +- tests/models/test_lm_eval.py | 3 +- tests/models/test_match_megatron.py | 3 +- tests/utils/dataset.py | 26 +++++------ tests/utils/global_variables.py | 48 ++++++++++++++++++++ tests/utils/model_configs.py | 2 +- tests/utils/utils.py | 13 +----- 19 files changed, 125 insertions(+), 85 deletions(-) create mode 100644 tests/utils/global_variables.py diff 
--git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 0da16428e..3ca2d71fa 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -23,7 +23,6 @@ class GPTHuggingfaceCheckpointFormat(CheckpointFormat): support_optimizer: typing.ClassVar[bool] = False - trust_remote_code: typing.ClassVar[bool] = False @classmethod def get_handler_class(cls) -> type[CheckpointHandler]: @@ -58,17 +57,14 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mtp_llama" - trust_remote_code: typing.ClassVar[bool] = True class DiffusionDreamGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "dream" - trust_remote_code: typing.ClassVar[bool] = True class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "diffusion_llama" - trust_remote_code: typing.ClassVar[bool] = True @config_class() diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 43e3c67e5..b5e77e0f0 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,8 @@ import pathlib import typing +from transformers import PretrainedConfig + from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( @@ -16,7 +18,7 @@ SplitWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig @@ -29,12 +31,14 @@ HybridSSMModelConfig, 
LLambaHuggingfaceCheckpointFormat, ) +from fast_llm.models.ssm.external.apriel_15b_hybrid import ( + configuration_ssm_hybrid_apriel15b, + modeling_ssm_hybrid_apriel15b, +) +from fast_llm.models.ssm.external.apriel_hybrid import configuration_ssm_hybrid_apriel, modeling_ssm_hybrid_apriel from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.utils import Assert -if typing.TYPE_CHECKING: - pass - class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: HybridSSMModel @@ -523,6 +527,11 @@ class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandle _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat architecture: typing.ClassVar[str] = "AprielSSMForCausalLM" + modeling_file = modeling_ssm_hybrid_apriel15b.__file__ + configuration_file = configuration_ssm_hybrid_apriel15b.__file__ + configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = ( + configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig + ) @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -635,6 +644,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class AprielSSMHHybridHuggingfaceCheckpointHandler( + CustomModelingExportMixin, HybridModelCheckpointHandler, # handles the block structure parameter CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers @@ -648,10 +658,21 @@ class AprielSSMHHybridHuggingfaceCheckpointHandler( format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHHybridHuggingfaceCheckpointFormat _default_block_type: str = SSMBlockType.mamba2_discrete.value architecture: typing.ClassVar[str] = "AprielSSMHybridForCausalLM" + modeling_file = modeling_ssm_hybrid_apriel.__file__ + configuration_file = configuration_ssm_hybrid_apriel.__file__ + configuration_cls: 
typing.ClassVar[type["PretrainedConfig"]] = modeling_ssm_hybrid_apriel.AprielSSMHybridConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_ssm_hybrid_apriel.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel.AprielSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel.AprielSSMHybridForCausalLM", + }, + ), RenameParamConverter( fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), @@ -693,6 +714,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( + CustomModelingExportMixin, HybridModelCheckpointHandler, # handles the block structure parameter CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers @@ -707,28 +729,23 @@ class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( _default_block_type: str = SSMBlockType.mamba2_discrete.value _hf_prefix: str = "model" architecture: typing.ClassVar[str] = "AprielThinkerSSMHybridForCausalLM" - - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() - # num_layers = self._model.config.base_model.transformer.num_layers - # # Embedding and output - # if self._model.config.base_model.tie_word_embeddings: - # converters.append( - # WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") - # ) - # converters.append(IgnoreImportWeightConverter((), f"{self._hf_prefix}.lm_head.weight")) - # else: - # converters.append( - # WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") - # ) - # converters.append( - # WeightConverter(f"layers.{num_layers + 1}.output_weights", 
f"{self._hf_prefix}.lm_head.weight") - # ) - return converters + modeling_file = modeling_ssm_hybrid_apriel15b.__file__ + configuration_file = configuration_ssm_hybrid_apriel15b.__file__ + configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = ( + configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig + ) @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", + }, + ), RenameParamConverter( fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), diff --git a/tests/conftest.py b/tests/conftest.py index 19bdfe5d9..86937326c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,27 +8,15 @@ import pytest import xdist.scheduler -from fast_llm.utils import get_and_reset_memory_usage_mib, set_global_variables +from fast_llm.utils import get_and_reset_memory_usage_mib from tests.utils.depends import DependencyManager +from tests.utils.global_variables import TEST_RESULTS_PATH, set_testing_global_variables # TODO: Is this early enough? -set_global_variables() # isort: skip - - -if worker_name := os.environ.get("PYTEST_XDIST_WORKER"): - if gpus := os.environ.get("CUDA_VISIBLE_DEVICES"): - # We set the device through "CUDA_VISIBLE_DEVICES", and this needs to happen before importing torch. 
- assert worker_name.startswith("gw") - worker_id = int(worker_name[2:]) - gpus = [int(i) for i in gpus.split(",")] - num_gpus = len(gpus) - gpus = [gpus[(i + worker_id) % num_gpus] for i in range(num_gpus)] - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpus) - +set_testing_global_variables() # isort: skip import torch # isort: skip - from tests.utils.save_load_configs import ( # isort: skip distributed_save_load_config, distributed_save_load_config_non_pp, @@ -44,7 +32,7 @@ ) from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip -from tests.utils.utils import result_path, TEST_RESULTS_PATH, format_resource_report, report_subtest # isort: skip +from tests.utils.utils import result_path, format_resource_report, report_subtest # isort: skip logger = logging.getLogger(__name__) diff --git a/tests/data/common.py b/tests/data/common.py index 2bb90a6b4..6614accce 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -23,7 +23,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert, div -from tests.utils.dataset import TEST_VOCAB_SIZE +from tests.utils.global_variables import TEST_VOCAB_SIZE def get_sampling_data( diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 3e6c37632..312807aad 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -11,7 +11,8 @@ get_sampling_data, get_test_data_and_compare_samples, ) -from tests.utils.dataset import DATASET_CACHE, DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX _DATASET_PREFIX_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 4f36cdf89..6cc5d639a 100644 --- a/tests/data/test_concatenate.py +++ 
b/tests/data/test_concatenate.py @@ -7,7 +7,8 @@ get_test_data_and_compare_samples, ) from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS -from tests.utils.dataset import DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX GPT_CONCATENATED_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], diff --git a/tests/data/test_concatenated_memmap.py b/tests/data/test_concatenated_memmap.py index 1cc22250d..35d93d9d5 100644 --- a/tests/data/test_concatenated_memmap.py +++ b/tests/data/test_concatenated_memmap.py @@ -9,7 +9,8 @@ validate_indexed_dataset_sampling, ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES -from tests.utils.dataset import DATASET_CACHE, get_test_concatenated_memmap_dataset +from tests.utils.dataset import get_test_concatenated_memmap_dataset +from tests.utils.global_variables import DATASET_CACHE _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP = DATASET_CACHE / "concatenated_memmap" diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py index 3f7d1a139..c149e1395 100644 --- a/tests/data/test_dataset_from_file.py +++ b/tests/data/test_dataset_from_file.py @@ -1,7 +1,8 @@ from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from tests.data.common import compare_indexed_dataset, get_dataset_config from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS -from tests.utils.dataset import DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX def test_dataset_from_file(): diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 004b96289..551134fd2 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -7,7 +7,8 @@ get_sampling_data, get_test_data_and_compare_samples, ) -from tests.utils.dataset import DATASET_PREFIX, 
TOKENIZER_PATH, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX, TOKENIZER_PATH GPT_FIM_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index fcd7756db..1286bddd7 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -4,7 +4,8 @@ from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig from tests.data.common import compare_indexed_dataset, get_dataset_config -from tests.utils.dataset import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE MEMMAP_DATASET_LENGTH = 6153 MEMMAP_DATASET_TOKENS = 508327 diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 32d76fa4c..a2996aa1c 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -13,7 +13,8 @@ get_test_data_and_compare_samples, validate_indexed_dataset_sampling, ) -from tests.utils.dataset import DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index f8eedc5bc..1440614cb 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -7,7 +7,8 @@ validate_indexed_dataset_sampling, ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES -from tests.utils.dataset import DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX GPT_SLICE_TRAINING_SAMPLES = [ [80, 268, 79, 260, 207, 3086], diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 4bda5512c..031ec6f97 100644 --- 
a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -317,9 +317,7 @@ def test_huggingface_model(model_testing_config, get_convert_path): if model_testing_config.name in ("diffusion_llama", "dream") else transformers.AutoModelForCausalLM ) - model_as_hf = auto_model.from_pretrained( - hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code - ).cuda() + model_as_hf = auto_model.from_pretrained(hf_path, trust_remote_code=True).cuda() for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), (model_from_fast_llm, model_from_hf, model_as_hf), diff --git a/tests/models/test_lm_eval.py b/tests/models/test_lm_eval.py index b9e2aa8c3..8011b5bbc 100644 --- a/tests/models/test_lm_eval.py +++ b/tests/models/test_lm_eval.py @@ -3,8 +3,9 @@ import pytest -from tests.utils.dataset import TOKENIZER_PATH, download_santacoder_tokenizer +from tests.utils.dataset import download_santacoder_tokenizer from tests.utils.distributed_configs import DistributedTestingConfig +from tests.utils.global_variables import TOKENIZER_PATH from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 30667cd17..5ff998bfa 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -3,8 +3,9 @@ import pytest from tests.utils.compare_tensor_logs import CompareConfig -from tests.utils.dataset import MODEL_DATASET_PREFIX, get_model_test_dataset +from tests.utils.dataset import get_model_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig +from tests.utils.global_variables import MODEL_DATASET_PREFIX from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index b770675d4..e4cce2935 100644 --- a/tests/utils/dataset.py +++ 
b/tests/utils/dataset.py @@ -1,27 +1,21 @@ import pathlib import random -import string import numpy as np import yaml from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample -from tests.utils.utils import SHARED_RESULT_PATH, TEST_RESULTS_PATH - -# TODO: Fixtures -TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" -TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" -DATASET_CACHE = SHARED_RESULT_PATH / "dataset" -DATASET_PREFIX = DATASET_CACHE / "common_dataset" -DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" -TEST_VOCAB_SIZE = 8192 -# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% -TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" -TEST_DATASET_TOKENS = 1000000 - -MODEL_DATASET_PREFIX = DATASET_CACHE / "model_dataset" -MODEL_TEST_VOCAB_SIZE = 384 +from tests.utils.global_variables import ( + DATASET_PREFIX, + MODEL_DATASET_PREFIX, + MODEL_TEST_VOCAB_SIZE, + TEST_CHARACTERS, + TEST_DATASET_TOKENS, + TEST_VOCAB_SIZE, + TOKENIZER_FILE, + TOKENIZER_PATH, +) def download_santacoder_tokenizer(): diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py new file mode 100644 index 000000000..80232bf53 --- /dev/null +++ b/tests/utils/global_variables.py @@ -0,0 +1,48 @@ +""" +This files holds global variables and settings that need to be defined before importing any third-party package. +They are kept in a separate file to prevent circular imports. +""" + +import os +import pathlib +import string + +from fast_llm.utils import set_global_variables + +# Directory for all test data and results. +# Cannot be a fixture because it's used outside testing environment (ex. distributed scripts). 
+TEST_RESULTS_PATH = pathlib.Path("/tmp/fast_llm_tests") + +WORKER_NAME = os.environ.get("PYTEST_XDIST_WORKER") +GPUS = os.environ.get("CUDA_VISIBLE_DEVICES") +SHARED_RESULT_PATH = TEST_RESULTS_PATH / (f"common_{WORKER_NAME}" if WORKER_NAME else "common") + + +def set_testing_global_variables(): + set_global_variables() # isort: skip + if WORKER_NAME: + if gpus := os.environ.get("CUDA_VISIBLE_DEVICES"): + # We set the device through "CUDA_VISIBLE_DEVICES", and this needs to happen before importing torch. + assert WORKER_NAME.startswith("gw") + worker_id = int(WORKER_NAME[2:]) + gpus = [int(i) for i in gpus.split(",")] + num_gpus = len(gpus) + gpus = [gpus[(i + worker_id) % num_gpus] for i in range(num_gpus)] + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpus) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") + os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") + + +# TODO: Fixtures +TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" +TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" +DATASET_CACHE = SHARED_RESULT_PATH / "dataset" +DATASET_PREFIX = DATASET_CACHE / "common_dataset" +DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" +TEST_VOCAB_SIZE = 8192 +# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% +TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" +TEST_DATASET_TOKENS = 1000000 + +MODEL_DATASET_PREFIX = DATASET_CACHE / "model_dataset" +MODEL_TEST_VOCAB_SIZE = 384 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 722d8d63a..e9bdeba97 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -24,8 +24,8 @@ AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat, ) -from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from tests.utils.distributed_configs import DistributedTestingConfig +from tests.utils.global_variables import 
MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from fast_llm.engine.evaluation.evaluators import ( # isort:skip # needed for dynamic type registration EvaluatorsConfig, diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 25d5221d8..88303a0f4 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -1,7 +1,6 @@ import json import logging import math -import os import pathlib import sys import time @@ -19,22 +18,12 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig from fast_llm.engine.multi_stage.stage import Stage from fast_llm.utils import get_and_reset_memory_usage_mib, header +from tests.utils.global_variables import TEST_RESULTS_PATH logger = logging.getLogger(__name__) requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -# Directory for all test data and results. -# Cannot be a fixture because it's used outside testing environment (ex. distributed scripts). -TEST_RESULTS_PATH = pathlib.Path("/tmp/fast_llm_tests") - -# Directory for data that is shared between independent tests and may not be parallel-safe, -# ex. generated dataset and downloaded files. 
-if worker_name := os.environ.get("PYTEST_XDIST_WORKER"): - SHARED_RESULT_PATH = TEST_RESULTS_PATH / f"common_{worker_name}" -else: - SHARED_RESULT_PATH = TEST_RESULTS_PATH / "common" - @pytest.fixture(scope="session") def result_path(): From bd4ff0d03fd7f878c6b8d1551ffa682326f2d150 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 12 Aug 2025 14:21:51 -0400 Subject: [PATCH 32/40] doc --- fast_llm/engine/config_utils/tensor_space.py | 86 +++++++++++++++++++- fast_llm/tensor.py | 8 +- 2 files changed, 91 insertions(+), 3 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 6c4b95b20..66176ee0f 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -15,6 +15,16 @@ class TensorDim: + """ + Describes a simple, atomic dimension of a tensor and its size. + The dimension may be parallelized along a distributed dimension `parallel_dim`, + in which case its actual (local) `size` will differ from its `global_size`. + + TensorDim's are used to represent the metadata of tensors through `TensorMeta`. + + This class also serves as a base for more complex tensor dimensions. + """ + def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): # TODO: Handle None for unknown sizes? self._name = name @@ -62,10 +72,25 @@ def parallel_group(self) -> "ProcessGroup|None": return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + """ + Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, + but the local size remains the same. + + Used in`TensorMeta.replace_tensor_parallel_dim`. 
+ """ assert self.is_parallel return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. + If the dimension is parallelized, this amounts to gathering along dimension `dim` + and parallel dimension `parallel_dim`, otherwise return the input tensor. + The method needs to be called my all members of the parallel group using their appropriate local slice. + + Used in`TensorMeta.local_to_global`, + which iterates over the tensor dimensions to fully reconstruct the global tensor. + """ if self.is_parallel: from fast_llm.core.ops import gather_op @@ -76,6 +101,14 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def local_to_global_partial( self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 ) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`. + Unlike `local_to_global`, this method does not need to be called from a distributed setting. + Instead, entries from other ranks are populated with `fill_value`. + + Used in`TensorMeta.local_to_global_partial`, + which iterates over the tensor dimensions to fully reconstruct the global tensor. + """ if self.is_parallel: output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) @@ -84,6 +117,14 @@ def local_to_global_partial( return tensor def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + """ + Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. 
+ If the dimension is parallel, this amounts to taking the `rank`th chunk of size `size` along dimension `dim` + and parallel dimension `self.parallel_dim`, otherwise return the input tensor. + + Used in`TensorMeta.local_to_global`, + which iterates over the tensor dimensions to fully reconstruct the local tensor. + """ return ( tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] if self.parallel_dim is not None and self.parallel_dim.size > 1 @@ -92,11 +133,20 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = F class CompositeTensorDim(TensorDim): + """ + A composite tensor dimension that represent multiple dimensions flattened into ones. + Typically happens for flattened view or higher-dimensional tensors, or tensors that can be expanded as such. + If one of the composed dimensions -- other than the first one -- is parallelized, + this is **not** equivalent to an atomic `TensorDim` of the same size, + as the relation between local and global tensors is different. + + At most one of the sub-dimensions may be parallelized. TODO: Allow for more than one? + """ + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): parallel_dim = None for dim, tensor_dim in enumerate(tensor_dims): if tensor_dim.parallel_dim is not None: - # TODO: Allow more than one parallel subdim? assert parallel_dim is None parallel_dim = tensor_dim.parallel_dim self._parallel_dim_index = dim @@ -109,12 +159,19 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + """ + Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, + but the local size remains the same. 
+ """ assert self._parallel_dim_index is not None dims = list(self._tensor_dims) dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) return CompositeTensorDim(self.name, tuple(dims)) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. + """ tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) for i, tensor_dim in enumerate(self._tensor_dims): tensor = tensor_dim.local_to_global(tensor, dim + i) @@ -124,6 +181,10 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def local_to_global_partial( self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 ) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`, + populating other ranks with `fill_value`. + """ tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) for i, tensor_dim in enumerate(self._tensor_dims): tensor = tensor_dim.local_to_global_partial(tensor, dim + i) @@ -131,6 +192,9 @@ def local_to_global_partial( return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + """ + Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. + """ tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): tensor = tensor_dim.global_to_local(tensor, dim + i) @@ -138,6 +202,12 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = F class ConcatenatedTensorDim(TensorDim): + """ + A complex tensor dimension that results from concatenating tensors. 
+ + All sub-dimensions should have the same `parallel_dim` (may be None). TODO: Allow for more complex scenarios? + """ + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): parallel_dim = tensor_dims[0].parallel_dim for dim, tensor_dim in enumerate(tensor_dims[1:]): @@ -152,12 +222,19 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + """ + Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, + but the local size remains the same. + """ assert self.is_parallel return ConcatenatedTensorDim( self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) ) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. + """ import torch return ( @@ -179,6 +256,10 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def local_to_global_partial( self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 ) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`, + populating other ranks with `fill_value`. + """ import torch return ( @@ -198,6 +279,9 @@ def local_to_global_partial( ) def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + """ + Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. 
+ """ if self.is_parallel and expand: raise NotImplementedError() import torch diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d080e6a1e..c17df9d0c 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -240,8 +240,12 @@ def validate(self, tensor: torch.Tensor, device: torch.device | None = None) -> return validate_tensor(tensor, self, device) def replace_tensor_parallel_dim(self, distributed_dim: DistributedDim) -> "TensorMeta": - # Replace the tensor-parallel `DistributedDim` in `meta`. - # Note: This will turn `ParameterMeta` into `TensorMeta` + """ + Replace the tensor-parallel `DistributedDim` in `meta`, preserving the local size. + Requires for advanced tensor manipulations, + ex. turn tensor-parallel slices of a tensor into slices of a different tensor-parallel size. + Note: This will turn `ParameterMeta` into `TensorMeta` + """ if not self.is_tensor_parallel: return self dims = list(self.dims) From 0e2e12402e162c4ad7a378b17522f03a24288ae6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 12 Aug 2025 17:17:07 -0400 Subject: [PATCH 33/40] stuff --- tests/utils/global_variables.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index 80232bf53..836b6b79d 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -29,8 +29,9 @@ def set_testing_global_variables(): num_gpus = len(gpus) gpus = [gpus[(i + worker_id) % num_gpus] for i in range(num_gpus)] os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpus) - os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") - os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") + # TODO: This might help with some issues, but slows down testing significantly. 
+ # os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") + # os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") # TODO: Fixtures From 9a2a7a27018f23f385d566bd9a94bf4affc02813 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 14:41:55 -0400 Subject: [PATCH 34/40] Pr comments --- fast_llm/layers/ssm/discrete_mamba2.py | 62 ++++++++++++++-------- fast_llm/layers/ssm/mamba2.py | 51 +++++++++++++----- fast_llm/layers/ssm/mamba_layer.py | 36 ++++++++----- fast_llm/layers/transformer/transformer.py | 2 +- 4 files changed, 104 insertions(+), 47 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index c9d555de9..b895412b5 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,14 +4,21 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_space import ( + CompositeTensorDim, + ConcatenatedTensorDim, + DefaultDimNames, + TensorDim, + TensorSpace, +) +from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ -from fast_llm.utils import get_lr_scale +from fast_llm.utils import div, get_lr_scale logger = logging.getLogger(__name__) @@ -49,25 +56,41 @@ def __init__( layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = 
get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] hidden_dim = tensor_space[TransformerDimNames.hidden] - conv1d_dim = tensor_space[SSMDimNames.concatenated_convolution] - heads_dim = tensor_space[SSMDimNames.composite_heads] + state_dim = TensorDim("state", self._config.state_size) + v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, self._config.n_v_heads)) + + head_groups_dim = TensorDim( + "head_groups", + self._config.n_qk_heads, + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), + ) + group_heads_dim = TensorDim("group_heads", div(self._config.n_v_heads, self._config.n_qk_heads)) + heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) + inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, v_head_size_dim)) + bc_dim = CompositeTensorDim("bc", (head_groups_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) + + inner_projection_dim = ConcatenatedTensorDim( + "inner_projection", + (inner_dim, bc_dim, bc_dim, inner_dim, heads_dim), + ) + convolution_dim = ConcatenatedTensorDim("convolution", (inner_dim, bc_dim, bc_dim)) # local_head_groups = head_groups / TP - self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + self._local_head_groups = head_groups_dim.size # local_heads = local_head_groups * group_heads self._local_heads = heads_dim.size # local_inner_size = local_heads * head_size self._local_inner_size = inner_dim.size # local_bc_size = local_head_groups * state - self._local_bc_size = tensor_space[SSMDimNames.composite_head_groups_and_state].size + self._local_bc_size = bc_dim.size # TODO: double check initializations # Projections self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=config.add_bias_linear, 
weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -82,15 +105,17 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( ( - conv1d_dim, + convolution_dim, tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + convolution_kernel_dim, + ), + init_method=init_uniform_centered_( + (convolution_dim.global_size * self._config.conv_kernel_dimension) ** -0.5 ), - init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (conv1d_dim,), + (convolution_dim,), init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale, ) @@ -122,14 +147,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) # inner_projection : (batch/local_or_padded_sequence, local_sequence/batch, hidden) - # -> (batch/local_or_padded_sequence, local_sequence/batch, inner_projection) - # inner_projection: (batch, local_or_padded_sequence, hidden) -> (batch, padded_sequence, local_inner_size) + # -> (batch/padded_sequence, sequence/batch, local_inner_projection) inner_projection = self.in_proj(input_) - # Standardize to (batch, padded_sequence, inner_projection) + # Standardize to (batch, padded_sequence, local_inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) - print("QAIKOFNMJOWENM inner_projection", inner_projection.shape) xBC, z, A_log = torch.split( inner_projection, [ @@ -139,9 +162,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ], dim=-1, ) - print("QAIKOFNMJOWENM xBC", xBC.shape, self._local_inner_size, self._local_bc_size) - print("QAIKOFNMJOWENM z", z.shape) - print("QAIKOFNMJOWENM A_log", A_log.shape) # Convolutional layer # 
xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) @@ -189,8 +209,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) # -> (batch/local_sequence, local_sequence/batch, hidden) a, b = self.out_proj(y) - logger.info(f"EKFBN y {y.shape}") - logger.info(f"EKFBN a {a.shape}") return self.out_proj(y) @torch.compile diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 77c1b3869..babbe6e05 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,7 +3,14 @@ import torch -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import ( + CompositeTensorDim, + ConcatenatedTensorDim, + DefaultDimNames, + TensorDim, + TensorSpace, +) +from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames @@ -62,13 +69,33 @@ def __init__( layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] + num_heads = div(self._config.d_inner, self._config.state_size) + num_head_groups = div(self._config.d_xb, self._config.state_size) + hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] - dt_rank_dim = tensor_space[SSMDimNames.dt_rank] + state_dim = TensorDim("state", self._config.state_size) + + head_groups_dim = TensorDim( + "head_groups", num_head_groups, self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + ) + group_heads_dim = TensorDim("group_heads", div(num_heads, num_head_groups)) + + heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) + + inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, state_dim)) + xb_dim = CompositeTensorDim("xb", (head_groups_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) - self._local_heads = tensor_space[SSMDimNames.composite_heads].size - self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + # DT projection + dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) + + inner_projection_dim = ConcatenatedTensorDim( + "inner_projection", + (inner_dim, xb_dim, xb_dim, inner_dim), + ) + + self._local_heads = heads_dim.size + self._local_head_groups = head_groups_dim.size self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size @@ -78,7 +105,7 @@ def __init__( ( conv1d_dim, tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + convolution_kernel_dim, ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -90,7 +117,7 @@ def __init__( ) self.in_proj = 
OutputParallelLinear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + convolution_kernel_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -122,7 +149,7 @@ def __init__( lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.state]), + (inner_dim, convolution_kernel_dim), init_method=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, weight_decay=False, @@ -139,7 +166,7 @@ def __init__( bias=config.add_bias_linear, weight_init_method=init_kaiming_(self._config.d_inner), sequence_parallel=self._sequence_parallel, - # TODO: lr_scale? + lr_scale=lr_scale, ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -147,10 +174,10 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ assert _causal_conv1d_available # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) - # -> (batch/sequence, sequence/batch, inner_projection) + # -> (batch/sequence, sequence/batch, local_inner_projection) inner_projection = self.in_proj(input_) dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias - # Standardize to (batch, sequence, inner_projection) + # Standardize to (batch, sequence, local_inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 9343ef1b8..061921b3d 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,14 +4,20 @@ import torch -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_space import ( + CompositeTensorDim, + ConcatenatedTensorDim, + DefaultDimNames, + TensorDim, + 
TensorSpace, +) from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, div, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -67,27 +73,33 @@ def __init__( self._config = config # TODO: It's not silu? Assert.eq(self._config.activation_type, ActivationType.silu) + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) # Tensor dims: - inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] hidden_dim = tensor_space[TransformerDimNames.hidden] - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + heads_dim = TensorDim("heads", div(self._config.d_inner, self._config.state_size)) + state_dim = TensorDim("state", self._config.state_size) + inner_dim = CompositeTensorDim("inner", (heads_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) + dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) + inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) + x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) # TODO: Backward compatibility? - # TODO: lr_scale? 
self.in_proj = Linear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=False, weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + convolution_kernel_dim, ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -95,7 +107,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space[SSMDimNames.concatenated_x_projection], + x_projection_dim, weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, @@ -104,7 +116,7 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.dt_rank]), + (inner_dim, dt_rank_dim), init_method=init_kaiming_(self._config.dt_rank), lr_scale=lr_scale, ) @@ -116,7 +128,7 @@ def __init__( ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.state]), + (inner_dim, state_dim), weight_decay=False, init_method=init_A(self._config.state_size, inner_dim.size), lr_scale=lr_scale, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 75d06f268..63f3aaab6 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -98,7 +98,7 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space[TransformerDimNames.hidden] - # Note, layer_lr_scale does not impact the norms + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) From 
8c382a902b91feff0300476be0164773dc47a807 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 14:43:40 -0400 Subject: [PATCH 35/40] Cleanup --- fast_llm/layers/ssm/config.py | 81 ++--------------------------------- 1 file changed, 4 insertions(+), 77 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 9b0949d55..e6f87cf27 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -2,11 +2,10 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig -from fast_llm.utils import Assert, div +from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.tensor import Initializer @@ -212,77 +211,5 @@ def _validate(self) -> None: Assert.geq(self.dt_max, self.dt_min) def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # Head groups are configured differently depending on the block type. - if block_type == SSMBlockType.mamba: - num_heads = div(self.d_inner, self.state_size) - num_head_groups = num_heads - elif block_type == SSMBlockType.mamba2: - num_heads = div(self.d_inner, self.state_size) - num_head_groups = div(self.d_xb, self.state_size) - elif block_type == SSMBlockType.mamba2_discrete: - # TODO: Use different variables? 
- num_heads = self.n_v_heads - num_head_groups = self.n_qk_heads - else: - raise NotImplementedError(block_type) - - tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size)) - if block_type == SSMBlockType.mamba2_discrete: - tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads))) - else: - head_dim = state - - tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) - tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) - tensor_space.add_tensor_dim( - heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) - ) - tensor_space.add_tensor_dim( - heads_and_head_dim := CompositeTensorDim( - SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim) - ) - ) - tensor_space.add_tensor_dim( - head_groups_and_state := CompositeTensorDim( - SSMDimNames.composite_head_groups_and_state, (head_groups, state) - ) - ) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension)) - - # DT projection - if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): - tensor_space.add_tensor_dim(dt_rank := TensorDim(SSMDimNames.dt_rank, self.dt_rank)) - - if block_type == SSMBlockType.mamba: - tensor_space.add_tensor_dim( - ConcatenatedTensorDim(SSMDimNames.concatenated_x_projection, (dt_rank, state, state)) - ) - # TODO: Use composition instead - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, (heads_and_head_dim, heads_and_head_dim) - ) - ) - elif block_type == SSMBlockType.mamba2: - # TODO: Factor out state? 
- tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim), - ) - ) - elif block_type == SSMBlockType.mamba2_discrete: - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim, heads), - ) - ) - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_convolution, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state), - ) - ) + # Handled in the model. + pass From 019e43dc6e95a4b2901b1ff3bd8dfacb65af961f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 14:50:13 -0400 Subject: [PATCH 36/40] Cleanup --- fast_llm/layers/ssm/mamba2.py | 38 +++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index babbe6e05..0408479e5 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -13,7 +13,7 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer @@ -44,18 +44,6 @@ class Mamba2(Mixer): _mixer_name: typing.ClassVar[str] = "mamba_2" - _XZ_DIMS = ( - TransformerDimNames.batch, - SSMDimNames.composite_heads_and_head_dim, - TransformerDimNames.sequence_q, - ) - _BC_DIMS = ( - TransformerDimNames.batch, - SSMDimNames.composite_heads, - 
SSMDimNames.state, - TransformerDimNames.sequence_q, - ) - def __init__( self, config: SSMConfig, @@ -168,6 +156,18 @@ def __init__( sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) + if self._debug.enabled: + self._xz_dims = ( + TransformerDimNames.batch, + inner_dim, + TransformerDimNames.sequence_q, + ) + self._bc_dims = ( + TransformerDimNames.batch, + heads_dim, + state_dim, + TransformerDimNames.sequence_q, + ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available @@ -224,11 +224,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ dt = dt.transpose(1, 2) if self._debug_level: - self._debug_log(z, "z", self._XZ_DIMS, kwargs) - self._debug_log(x, "x", self._XZ_DIMS, kwargs) - self._debug_log(b, "b", self._BC_DIMS, kwargs) - self._debug_log(c, "c", self._BC_DIMS, kwargs) - self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + self._debug_log(z, "z", self._xz_dims, kwargs) + self._debug_log(x, "x", self._xz_dims, kwargs) + self._debug_log(b, "b", self._bc_dims, kwargs) + self._debug_log(c, "c", self._bc_dims, kwargs) + self._debug_log(dt, "dt", self._xz_dims, kwargs) y = selective_scan_fn( x, @@ -243,7 +243,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ) if self._debug_level: - self._debug_log(y, "y", self._XZ_DIMS, kwargs) + self._debug_log(y, "y", self._xz_dims, kwargs) # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] From 3e0f3e555ab92eca7a62378d6a5ad366f5118bda Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 14:50:38 -0400 Subject: [PATCH 37/40] Cleanup --- fast_llm/layers/ssm/config.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index e6f87cf27..fb178e7d5 100644 --- a/fast_llm/layers/ssm/config.py 
+++ b/fast_llm/layers/ssm/config.py @@ -11,28 +11,6 @@ from fast_llm.tensor import Initializer -class SSMDimNames: - # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. - state = "ssm_state" # State dimension (N), aka head size / num channels - head_dim = "ssm_head_dim" - head_groups = "ssm_head_groups" - group_heads = "ssm_group_heads" - - convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers - - dt_rank = "ssm_dt_rank" - - # Composite dimensions - composite_heads = "ssm_composite_heads" - composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" - composite_head_groups_and_state = "ssm_composite_head_groups_and_state" - - # Concatenated dimensions - concatenated_convolution = "ssm_concatenated_convolution" - concatenated_x_projection = "ssm_x_concatenated_x_projection" - concatenated_inner_projection = "ssm_concatenated_inner_projection" - - class SSMBlockType(enum.StrEnum): """ An enum for the available mamba types for the MLP layer. 
From 1abdd19280f8cf7104236be375aa14ceeea235ee Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 15:09:24 -0400 Subject: [PATCH 38/40] fixes --- fast_llm/layers/common/config.py | 4 ++-- fast_llm/layers/ssm/discrete_mamba2.py | 2 +- fast_llm/layers/ssm/mamba2.py | 6 ++++-- fast_llm/layers/transformer/transformer.py | 5 +++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 07dadbc22..710b2668f 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -43,7 +43,7 @@ class NormalizationConfig(BaseModelConfig): pass @abc.abstractmethod - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None) -> "torch.nn.Module": pass @classmethod @@ -63,7 +63,7 @@ def _from_dict( class NoNormalizationConfig(NormalizationConfig): _abstract = False - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None) -> "torch.nn.Module": return torch.nn.Identity() diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index b895412b5..47a94214a 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -63,7 +63,7 @@ def __init__( head_groups_dim = TensorDim( "head_groups", self._config.n_qk_heads, - self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), + self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor), ) group_heads_dim = TensorDim("group_heads", div(self._config.n_v_heads, self._config.n_qk_heads)) heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 0408479e5..95febb1c6 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -64,7 
+64,9 @@ def __init__( state_dim = TensorDim("state", self._config.state_size) head_groups_dim = TensorDim( - "head_groups", num_head_groups, self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + "head_groups", + num_head_groups, + self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor), ) group_heads_dim = TensorDim("group_heads", div(num_heads, num_head_groups)) @@ -105,7 +107,7 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - convolution_kernel_dim, + inner_projection_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 63f3aaab6..c7becd948 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -100,8 +100,9 @@ def __init__( hidden_dim = self._tensor_space[TransformerDimNames.hidden] # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) + lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale) + self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale) # The mixer needs to be created here for backward-compatible weight ordering. 
setattr(self, self._mixer_module_name, self._create_mixer()) From 7c2429292e5b56a763d39d67096d2931e657098d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 15:49:49 -0400 Subject: [PATCH 39/40] fixes --- fast_llm/layers/ssm/mamba2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 95febb1c6..802d757eb 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -158,7 +158,7 @@ def __init__( sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - if self._debug.enabled: + if self._debug_level: self._xz_dims = ( TransformerDimNames.batch, inner_dim, From af2964bfee592db2a59789cde413deba3acd3d1d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 15:59:44 -0400 Subject: [PATCH 40/40] fixes --- fast_llm/layers/ssm/mamba2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 802d757eb..7151da394 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -139,7 +139,7 @@ def __init__( lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (inner_dim, convolution_kernel_dim), + (inner_dim, state_dim), init_method=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, weight_decay=False,