diff --git a/fast_llm/engine/optimizer/learning_rate.py b/fast_llm/engine/optimizer/learning_rate.py index bf11038a5..c6912e4f1 100644 --- a/fast_llm/engine/optimizer/learning_rate.py +++ b/fast_llm/engine/optimizer/learning_rate.py @@ -120,19 +120,19 @@ def create_schedule_from_config(config: LearningRateScheduleConfig) -> LearningR begin_step = 0 for stage_arg_str in config.schedule.split(";"): try: - for stage_type, num_steps, lr, *stage_args in stage_arg_str.split(","): - assert begin_step is not None - num_steps = int(num_steps) - end_step = None if num_steps < 0 else begin_step + num_steps - kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)} - if len(stage_args) > 0: - kwargs["end_lr"] = float(stage_args[0]) - if len(stage_args) > 1: - kwargs["power"] = float(stage_args[1]) - if len(stage_args) > 2: - raise ValueError(stage_args[2:]) - stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs)) - begin_step = end_step + stage_type, num_steps, lr, *stage_args = stage_arg_str.split(",") + assert begin_step is not None + num_steps = int(num_steps) + end_step = None if num_steps < 0 else begin_step + num_steps + kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)} + if len(stage_args) > 0: + kwargs["end_lr"] = float(stage_args[0]) + if len(stage_args) > 1: + kwargs["power"] = float(stage_args[1]) + if len(stage_args) > 2: + raise ValueError(stage_args[2:]) + stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs)) + begin_step = end_step except Exception: raise ValueError(f'Cannot parse optimizer stage definition "{stage_arg_str}"') return LearningRateSchedule(stages) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 054c26c3c..1b84eeb2c 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -11,6 +11,18 @@ from fast_llm.layers.common.normalization import LayerNorm, RMSNorm +@config_class() +class LLMBlockConfig(BaseModelConfig): + _abstract = False + + per_layer_lr_scale: list[float] | None = Field( + default=None, + desc="Custom learning rate scale for each layer.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + ) + + class NormalizationImplementation(str, enum.Enum): """ An enum for the available implementations of layer norm. @@ -68,7 +80,7 @@ class NormalizationConfig(BaseModelConfig): valid=check_field(Assert.geq, 0), ) - def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": from fast_llm.layers.common.normalization import LayerNorm, RMSNorm from fast_llm.tensor import init_uniform_ @@ -77,6 +89,7 @@ def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": "eps": self.epsilon, "implementation": self.implementation, "zero_centered": self.zero_centered, + "lr_scale": lr_scale, } if self.initialization_range: mean = 0 if self.zero_centered else 1 diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 848abb974..5f30beaef 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -155,6 +155,7 @@ def __init__( weight_init_method=None, bias_init_method=init_zeros_, zero_centered: bool = False, + lr_scale: float | None = None, ): super().__init__() assert hidden_dim.parallel_dim is None @@ -193,12 +194,14 @@ def __init__( init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, + lr_scale=lr_scale, ) self.bias = ParameterMeta.from_dims( (hidden_dim,), init_method=bias_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, + lr_scale=lr_scale, ) self.normalized_shape = self.weight.shape @@ -236,6 +239,7 @@ def __init__( implementation: NormalizationImplementation = NormalizationImplementation.auto, weight_init_method=None, zero_centered: bool = False, + lr_scale: float | None = None, ): super().__init__() assert hidden_dim.parallel_dim is None @@ -269,6 +273,7 @@ def __init__( init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=True, + lr_scale=lr_scale, ) self.normalized_shape = self.weight.shape diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 2d5fd8436..6e6a8ae56 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -155,6 +155,25 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + embeddings_lr_scale: float | None = Field( + default=None, + desc="Learning rate scale for the word embeddings.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + output_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the output weights.", + doc="May be used to freeze the output weights by setting their scale to zero.", + hint=FieldHint.feature, + ) + prediction_loss_coefficient: list[float] | None = Field( + default=None, + desc="Loss coefficient for each prediction head.", + doc="If not provided, all heads are equally weighted.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.transformer.validate() @@ -173,6 +192,10 @@ def _validate(self) -> None: if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") + if isinstance(self.prediction_loss_coefficient, list): + Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) + for coeff in self.prediction_loss_coefficient: + Assert.geq(coeff, 0) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: self.transformer.setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d9406ed1..e0386d8df 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -62,6 +62,7 @@ def __init__( min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.embeddings_lr_scale, ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( @@ -72,6 +73,7 @@ def __init__( max_val=config.init_method_max_embed, ), allow_sequence_tensor_parallel=not config.parallel_embeddings, + lr_scale=config.embeddings_lr_scale, ) # PEFT. diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index d6d1b8a54..233887ec6 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -60,6 +60,9 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + self._loss_coefficient = ( + config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 + ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) self.final_norm = config.transformer.normalization.get_layer(hidden_dim) self._logits_scale_factor = config.logits_scale_factor @@ -109,6 +112,7 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.output_lr_scale, ) def forward( @@ -139,7 +143,7 @@ def forward( else: if self.training: # Backward hook to compute the gradient of the loss - shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0) + shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, self._loss_coefficient) # MTP: Return shared_hidden to be used by the next head. return shared_hidden diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 25ad3d225..13418254c 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,7 +1,8 @@ -from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import BaseModelConfig +import enum + +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig from fast_llm.utils import Assert @@ -20,8 +21,19 @@ class SSMDimNames: v_heads = "v_heads" # Number of V heads +class SSMBlockType(str, enum.Enum): + """ + An enum for the available mamba types for the MLP layer. + """ + + mamba = "m" + mamba2_discrete = "m2d" + mamba2 = "m2" + transformer = "t" + + @config_class() -class SSMConfig(BaseModelConfig): +class SSMConfig(LLMBlockConfig): _abstract = False # Normalization @@ -53,7 +65,8 @@ class SSMConfig(BaseModelConfig): desc="Whether to use bias in SSM layers", hint=FieldHint.architecture, ) - dt_rank: int = Field( + + dt_rank: None | int = Field( default=None, desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", hint=FieldHint.architecture, @@ -102,12 +115,22 @@ class SSMConfig(BaseModelConfig): valid=check_field(Assert.gt, 0), ) + d_inner: None | int = Field( + default=None, + desc="Inner dimension for Mamba2 blocks.", + hint=FieldHint.core, + ) + mamba_lr_scale: float | None = Field( + default=None, + desc="Learning rate scale for Mamba blocks.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: self.activation_type = ActivationType.silu - if self.dt_rank is None: - self.dt_rank = -1 # set to -1, it will be overwrittem in ssm validation super()._validate() Assert.geq(self.dt_max, self.dt_min) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index bf0128c89..85916244e 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,6 +1,6 @@ +import logging import math -import causal_conv1d import einops import mamba_ssm.ops.triton.ssd_combined import torch @@ -9,6 +9,16 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.utils import get_lr_scale + +logger = logging.getLogger(__name__) + +try: + import causal_conv1d +except ImportError: + # this is needed since we cannot use causal_conv1d on B200 GPUs for now + logger.warning("Note, causal_conv1d not found, will use torch.nn.functional.conv1d instead") + causal_conv1d = None """ This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py @@ -44,6 +54,9 @@ def __init__( bias = config.add_bias_linear self.layer_idx = layer_idx self._return_input = return_input + layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) + logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) @@ -67,31 +80,41 @@ def __init__( # TODO: double check initializations # Projections - self.in_proj = Linear(td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size)) + self.in_proj = Linear( + td_model, + td_inner_proj, + bias=bias, + weight_init_method=kaiming_init_(td_model.size), + lr_scale=mamba_layer_lr_scale, + ) self.z_bias = ( ParameterMeta.from_dims( (td_inner,), weight_decay=False, init_method=init_zeros_, + lr_scale=mamba_layer_lr_scale, ) if not bias else 0.0 ) - # Convolutional layer self.conv1d_weight = ParameterMeta.from_dims( (td_conv, TensorDim("1", 1), td_conv_kernel), init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 + lr_scale=mamba_layer_lr_scale, + ) + self.conv1d_bias = ParameterMeta.from_dims( + (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale ) - self.conv1d_bias = ParameterMeta.from_dims((td_conv,), init_method=bias_init_method(self.conv1d_weight)) # D "skip" parameter self.D = ParameterMeta.from_dims( (td_n_qk_heads,), weight_decay=False, init_method=init_ones_, + lr_scale=mamba_layer_lr_scale, ) # out_proj @@ -100,6 +123,7 @@ def __init__( td_model, bias=bias, weight_init_method=kaiming_init_(td_inner.size), + lr_scale=mamba_layer_lr_scale, ) @property @@ -210,10 +234,25 @@ def forward(self, hidden_states, kwargs): def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" - xBC = causal_conv1d.causal_conv1d_fn( - xBC.transpose(1, 2), - einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), - self.conv1d_bias, - activation=None if self.activation_name == "identity" else self.activation_name, - ).transpose(1, 2) + if causal_conv1d is None or self.activation_name not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act( + torch.nn.functional.conv1d( + xBC.transpose(1, 2), + self.conv1d_weight, + bias=self.conv1d_bias, + groups=self.conv1d_weight.shape[0], + padding=self.conv_kernel_size - 1, + )[..., :padded_len].transpose(1, 2) + ) + else: + xBC = causal_conv1d.causal_conv1d_fn( + xBC.transpose(1, 2), + einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + self.conv1d_bias, + activation=None if self.activation_name == "identity" else self.activation_name, + ).transpose(1, 2) return xBC diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 3d9cc05b8..7d0ee48a4 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -9,6 +9,7 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ +from fast_llm.utils import get_lr_scale """ Note: this is mostly adapted from https://github.com/Zyphra/Zamba2, similar code is also in https://github.com/state-spaces/mamba. @@ -81,6 +82,8 @@ def __init__( self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size + layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) self.in_proj_weight = ParameterMeta.from_dims( (td_inner_proj, td_model), @@ -90,6 +93,7 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), init_method=kaiming_init_(td_inner.size), + lr_scale=mamba_layer_lr_scale, ) self.conv1d_bias = None @@ -102,6 +106,7 @@ def __init__( td_x_proj, weight_init_method=kaiming_init_(td_inner.size), bias=False, + lr_scale=mamba_layer_lr_scale, **factory_kwargs, ) self.x_proj.weight.auto_grad_accumulation = True @@ -110,6 +115,7 @@ def __init__( self.dt_proj_weight = ParameterMeta.from_dims( (td_inner, tdt_rank), init_method=kaiming_init_(tdt_rank.size), + lr_scale=mamba_layer_lr_scale, ) self.dt_proj_bias = ParameterMeta.from_dims( @@ -117,12 +123,14 @@ def __init__( init_method=init_dtprojbias( self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs ), + lr_scale=mamba_layer_lr_scale, ) self.A_log = ParameterMeta.from_dims( (td_inner, td_state), weight_decay=False, init_method=init_A(self.d_state, self.d_inner), + lr_scale=mamba_layer_lr_scale, ) # D "skip" parameter @@ -130,6 +138,7 @@ def __init__( (td_inner,), weight_decay=False, init_method=init_ones_, + lr_scale=mamba_layer_lr_scale, ) self.out_proj = Linear( @@ -137,6 +146,7 @@ def __init__( td_model, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. weight_init_method=kaiming_init_(td_model.size), + lr_scale=mamba_layer_lr_scale, **factory_kwargs, ) self.out_proj.weight.auto_grad_accumulation = True diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 0b442f661..0517c49ce 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -17,7 +17,7 @@ ) from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -110,6 +110,9 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) + # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, @@ -118,7 +121,7 @@ def __init__( weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, @@ -127,7 +130,7 @@ def __init__( weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) @@ -139,7 +142,7 @@ def __init__( weight_init_method=init_method_std_attn_proj, bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) # PEFT. diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index e7ef0b15f..9cc9510b5 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -11,7 +11,7 @@ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.layers.common.config import NormalizationConfig, PeftConfig, PeftType +from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig, PeftConfig, PeftType from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -257,7 +257,7 @@ def _validate(self) -> None: @config_class() -class TransformerConfig(BaseModelConfig): +class TransformerConfig(LLMBlockConfig): _abstract = False normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers architecture.", diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 85c6686f4..49778c63f 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -21,7 +21,7 @@ from fast_llm.layers.transformer.mlp import MLPBase from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta, init_normal_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." @@ -59,6 +59,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) + self.router = Linear( tensor_space.get_tensor_dim(TransformerDimNames.hidden), tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), @@ -66,7 +69,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max ), - lr_scale=config.router_lr_scale, + lr_scale=router_lr_scale, ) dropless_moe = config.dropless_moe if dropless_moe and tensor_space.distributed_config.sequence_tensor_parallel: diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 1c38705f9..c4d8afdc7 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -10,13 +10,14 @@ from fast_llm.layers.common.linear import LinearBase from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): super().__init__() self._name = name + self._layer_index = layer_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -38,6 +39,10 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale + lr_scale = get_lr_scale(lr_scale, layer_lr_scale) + # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, @@ -45,7 +50,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, - lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, + lr_scale=lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, @@ -55,7 +60,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, - lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, + lr_scale=lr_scale, ) # PEFT. @@ -64,7 +69,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): Assert.eq(config.num_experts, 1) super().__init__(config, tensor_space, name) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 40dd2e00e..b51ba1e94 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -38,13 +38,15 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + # Note, layer_lr_scale does not impact the norms + # TODO: add a seperate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) self._create_mixer() self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp" + self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index ) # PEFT. diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 771a4fcaf..75bb2a3be 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -9,11 +9,12 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames +from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.models.gpt.model import GPTInferenceRunner from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.models.ssm.trainer import SSMTrainer @@ -29,10 +30,10 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] = Field( - default_factory=lambda: ["m2"], - desc="Pattern of blocks to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for Discrete Mamba2", - hint=FieldHint.architecture, + hybrid_block_layout: list[str] | None = Field( + default=None, + desc=f"Pattern of blocks to use in the model. Availabel types: {SSMBlockType.__members__.values()}", + hint=FieldHint.core, ) default_mtp_type: str | None = Field( default=None, @@ -49,17 +50,24 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: Some of these can be setup directly in the layer config, but keeping them here for clarity. """ super().setup_tensor_space(tensor_space) - if not "m2" in self.hybrid_block_layout and not "m" in self.hybrid_block_layout: - raise ValueError( - "Block pattern must contain at least one 'm' or 'm2', use gpt model for transformer only architectures" - ) - - if self.ssm.dt_rank < 0: + # if ( + # not SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout + # and not SSMBlockType.mamba.value in self.hybrid_block_layout + # ): + # raise ValueError( + # f"Block pattern must contain at least one '{SSMBlockType.mamba2_discrete.value}' or '{SSMBlockType.mamba.value}', use gpt model for transformer only architectures" + # ) + + if self.ssm.dt_rank is None: mamba_dt_rank = math.ceil(self.transformer.hidden_size / 16) else: mamba_dt_rank = self.ssm.dt_rank - d_inner = int(self.ssm.expansion_factor * self.transformer.hidden_size) + d_inner = ( + int(self.ssm.expansion_factor * self.transformer.hidden_size) + if self.ssm.d_inner is None + else self.ssm.d_inner + ) # Hidden dimension tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) # Mamba-specific dimensions @@ -70,7 +78,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - if "m2" in self.hybrid_block_layout or self.default_mtp_type == "m2": + if SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout: # Mamba2 specific dimensions # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 headdim = d_inner // self.ssm.n_v_heads @@ -88,26 +96,29 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) def _validate(self): - len_block_layout = len(self.hybrid_block_layout) - if len_block_layout != self.transformer.num_layers: - if self.transformer.num_layers % len_block_layout != 0: + if self.hybrid_block_layout is None: + with self._set_implicit_default(): + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + + if len(self.hybrid_block_layout) != self.transformer.num_layers: + if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: raise ValueError( - f"hybrid_block_layout length {len_block_layout} does not match num_layers {self.transformer.num_layers}" + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" ) - num_repeats = int(self.transformer.num_layers // len_block_layout) + num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) logger.warning( - f"hybrid_block_layout length {len_block_layout} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" ) self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len_block_layout, self.transformer.num_layers) + Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) Assert.custom( - lambda _: all(block_type in ["t", "m", "m2"] for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be 't' or 'm' or 'm2'", + lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), + f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", ) Assert.custom( - lambda _: self.default_mtp_type in ["t", "m", "m2", None], - f"Invalid MTP type: {self.default_mtp_type}. Must be 't' or 'm' or 'm2' or None", + lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, + f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", ) super()._validate() @@ -124,11 +135,50 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return LLambaHuggingfaceCheckpointHandler +class AprielSSMHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "apriel_ssm" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import AprielSSMHuggingfaceCheckpointHandler + + return AprielSSMHuggingfaceCheckpointHandler + + +class AprielSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "apriel_ssm_hybrid" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import AprielSSMHHybridHuggingfaceCheckpointHandler + + return AprielSSMHHybridHuggingfaceCheckpointHandler + + +class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import AprielThinkerSSMHHybridHuggingfaceCheckpointHandler + + return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler + + @config_class() class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" base_model: HybridSSMBaseModelConfig = FieldUpdate() + checkpoint_formats = FastLLMModelConfig.checkpoint_formats + ( + LLambaHuggingfaceCheckpointFormat, + AprielSSMHuggingfaceCheckpointFormat, + AprielSSMHHybridHuggingfaceCheckpointFormat, + AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + ) checkpoint_formats = FastLLMModelConfig.checkpoint_formats + (LLambaHuggingfaceCheckpointFormat,) @classmethod @@ -160,9 +210,39 @@ class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate() batch: GPTBatchConfig = FieldUpdate() + reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["SSMTrainer"]: from fast_llm.models.ssm.trainer import SSMTrainer return SSMTrainer + + def _validate(self) -> None: + super()._validate() + if (name := self.model.base_model.distillation_model) is None: + Assert.empty(self.reference_models) + else: + Assert.eq(self.reference_models.keys(), {name}) + if self.model.base_model.use_absolute_position_embeddings: + Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + # if self.model.base_model.distillation_model is not None: + # # TODO: Support loss masking for distillation? + # assert not self.batch.use_loss_masking_spans + for reference_model in self.reference_models.values(): + Assert.none(reference_model.model.base_model.distillation_model) + # TODO: Support more LM head features. + Assert.none(reference_model.model.base_model.cross_entropy_splits) + Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) + Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) + + @classmethod + def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: + from fast_llm.models.gpt.model import GPTInferenceRunner + + # TODO: we dont have inference runner for SSM/Hybrid yet, should return None? + logger.warning( + "No inference runner for SSM/Hybrid yet, using GPTInferenceRunner for now, which does not support SSM/Hybrid" + ) + + return GPTInferenceRunner diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 190b2ffae..c5f22ccc5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -5,7 +5,9 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + ConstantExportParamConverter, ConstantImportParamConverter, + IgnoreImportParamConverter, IgnoreImportWeightConverter, MappedConfigParamConverter, ParamConverter, @@ -17,8 +19,15 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationType -from fast_llm.models.gpt.conversion import MLPLayer2Converter -from fast_llm.models.ssm.config import HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat +from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter +from fast_llm.models.ssm.config import ( + AprielSSMHHybridHuggingfaceCheckpointFormat, + AprielSSMHuggingfaceCheckpointFormat, + AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + HybridSSMModelConfig, + LLambaHuggingfaceCheckpointFormat, +) from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.utils import Assert @@ -26,74 +35,46 @@ pass -class LLambaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): +class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + """ + This is a temporary solution for importing/exporting hybrid models. Since there is no standard solution for this in HF, we just use the block_pattern. + If block_pattern is None, it will multiply the provided default block type by the number of layers and export/import it. + If block_pattern is provided, it will export/import it as-is. + """ + + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + _default_block_type: str = SSMBlockType.mamba2_discrete.value + + @classmethod + def _import_config(cls, config): + cls.num_layers = config["n_layer"] if "n_layer" in config else config["num_hidden_layers"] + cls.block_pattern = config.get("hybrid_block_layout", None) + return super()._import_config(config) + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + if cls.block_pattern is not None: + block_converter = RenameParamConverter( + fast_llm_names=(("hybrid_block_layout",),), + export_names=(("hybrid_block_layout",),), + ) + else: + block_converter = ConstantImportParamConverter( + fast_llm_names=(("hybrid_block_layout",),), + fast_llm_value=[cls._default_block_type] * cls.num_layers, + ) + + return super()._create_config_converters() + [block_converter] + + +class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - """ - Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json - """ return super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("n_layer",),), - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), - # TODO: is there an equivalen of pad_vocab_size_multiple in FastLLM, does it matter? - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) - ), - RenameParamConverter( - fast_llm_names=(("ssm", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm - ), - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_embeddings",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("d_model",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=( - ( - "mlp_cfg", - "intermediate_size", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "add_linear_biases"),), - export_names=( - ( - "mlp_cfg", - "bias", - ), - ), - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=( - ( - "mlp_cfg", - "act_fn", - ), - ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), RenameParamConverter( fast_llm_names=(("ssm", "state_size"),), export_names=( @@ -162,6 +143,142 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _create_weight_converters(self) -> list[WeightConverter]: + converters = super()._create_weight_converters() or [] + + num_layers = self._model.config.base_model.transformer.num_layers + ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear + + for i in range(num_layers): + # SSM + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias + ) + converters.append( + WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_weight", + f"model.layers.{i}.mixer.conv1d.weight", + self._model.config.base_model, + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_bias", + f"model.layers.{i}.mixer.conv1d.bias", + self._model.config.base_model, + ) + ) + + return converters + + def _get_weight_and_bias_converters( + self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + ) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters + + +class LLambaHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat + _hf_prefix: str = "backbone" + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + """ + Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json + """ + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) + ), + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("n_layer",),), + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=( + ( + "mlp_cfg", + "act_fn", + ), + ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "add_linear_biases"),), + export_names=( + ( + "mlp_cfg", + "bias", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=( + ( + "mlp_cfg", + "intermediate_size", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("d_model",),), + ), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_embeddings",),), + ), + ] + + def _create_weight_converters(self) -> list[WeightConverter]: + # not using super() because LLamba model is called backbone in the checkpoints converters = [] num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False @@ -169,58 +286,68 @@ def _create_weight_converters(self) -> list[WeightConverter]: # Embedding and output if self._model.config.base_model.tie_word_embeddings: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + converters.append( + WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") + ) + converters.append(IgnoreImportWeightConverter((), f"{self._hf_prefix}.lm_head.weight")) else: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + converters.append( + WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") + ) + converters.append( + WeightConverter(f"layers.{num_layers + 1}.output_weights", f"{self._hf_prefix}.lm_head.weight") + ) # Final norm converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "backbone.final_layernorm", norm_bias + f"layers.{num_layers + 1}.final_norm", f"{self._hf_prefix}.final_layernorm", norm_bias ) for i in range(num_layers): # SSM converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"backbone.layers.{i}.mixer.in_proj", ssm_bias + f"layers.{i+1}.mixer.in_proj", f"{self._hf_prefix}.layers.{i}.mixer.in_proj", ssm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"backbone.layers.{i}.mixer.out_proj", ssm_bias + f"layers.{i+1}.mixer.out_proj", f"{self._hf_prefix}.layers.{i}.mixer.out_proj", ssm_bias ) converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"backbone.layers.{i}.mixer.D", self._model.config.base_model) + WeightConverter( + f"layers.{i+1}.mixer.D", f"{self._hf_prefix}.layers.{i}.mixer.D", self._model.config.base_model + ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"backbone.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{i+1}.mixer.z_bias", + f"{self._hf_prefix}.layers.{i}.mixer.z_bias", + self._model.config.base_model, ) ) converters.append( WeightConverter( f"layers.{i+1}.mixer.conv1d_weight", - f"backbone.layers.{i}.mixer.conv1d.weight", + f"{self._hf_prefix}.layers.{i}.mixer.conv1d.weight", self._model.config.base_model, ) ) converters.append( WeightConverter( f"layers.{i+1}.mixer.conv1d_bias", - f"backbone.layers.{i}.mixer.conv1d.bias", + f"{self._hf_prefix}.layers.{i}.mixer.conv1d.bias", self._model.config.base_model, ) ) # Norm converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_1", f"backbone.layers.{i}.input_layernorm", norm_bias + f"layers.{i+1}.norm_1", f"{self._hf_prefix}.layers.{i}.input_layernorm", norm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_2", f"backbone.layers.{i}.post_attention_layernorm", norm_bias + f"layers.{i+1}.norm_2", f"{self._hf_prefix}.layers.{i}.post_attention_layernorm", norm_bias ) # MLP - converters += self._get_mlp_converters(f"layers.{i+1}", f"backbone.layers.{i}") + converters += self._get_mlp_converters(f"layers.{i+1}", f"{self._hf_prefix}.layers.{i}") return converters @@ -282,3 +409,253 @@ def _load_config(cls, directory: pathlib.Path | str) -> dict: def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: with open(directory / "config.json", "w") as f: json.dump(config, f) + + +class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): + """ + Lamba-like configs, pure SSM models. + """ + + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "d_inner"),), + export_names=(("ssm_cfg", "d_inner"),), + ), + ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=(("intermediate_size",),), + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + RenameParamConverter( + fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_word_embeddings",),), + ), + ] + + def _create_weight_converters(self) -> list[WeightConverter]: + converters = super()._create_weight_converters() + num_layers = self._model.config.base_model.transformer.num_layers + norm_bias: bool = False + + # Embedding and output + if self._model.config.base_model.tie_word_embeddings: + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + else: + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + + # Final norm + converters += self._get_weight_and_bias_converters( + f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias + ) + + for i in range(num_layers): + # Norm + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.norm_1", f"model.layers.{i}.input_layernorm", norm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.norm_2", f"model.layers.{i}.post_attention_layernorm", norm_bias + ) + + # MLP + converters += self._get_mlp_converters(f"layers.{i+1}", f"model.layers.{i}") + + return converters + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), + ] + + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + if not os.path.exists(directory / "config.json"): + raise FileNotFoundError(f"config.json not found in {directory}") + with open(directory / "config.json") as f: + config = json.load(f) + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + with open(directory / "config.json", "w") as f: + json.dump(config, f) + + +class AprielSSMHHybridHuggingfaceCheckpointHandler( + HybridModelCheckpointHandler, # handles the block structure parameter + CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers + CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers +): + """ + Lamba-like configs, models that interleave LLama like layers with LLamba-like SSM layers. + """ + + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHHybridHuggingfaceCheckpointFormat + _default_block_type: str = SSMBlockType.mamba2_discrete.value + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("ssm", "d_inner"),), + export_names=(("ssm_cfg", "d_inner"),), + ), + ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), + ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), + ] + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), + ] + + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + if not os.path.exists(directory / "config.json"): + raise FileNotFoundError(f"config.json not found in {directory}") + with open(directory / "config.json") as f: + config = json.load(f) + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + with open(directory / "config.json", "w") as f: + json.dump(config, f) + + +class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( + HybridModelCheckpointHandler, # handles the block structure parameter + CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers + CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers +): + """ + Lamba-like configs, models that interleave LLama like layers with LLamba-like SSM layers. + """ + + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = AprielThinkerSSMHHybridHuggingfaceCheckpointFormat + _default_block_type: str = SSMBlockType.mamba2_discrete.value + _hf_prefix: str = "model" + + def _create_weight_converters(self) -> list[WeightConverter]: + converters = super()._create_weight_converters() + # num_layers = self._model.config.base_model.transformer.num_layers + # # Embedding and output + # if self._model.config.base_model.tie_word_embeddings: + # converters.append( + # WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") + # ) + # converters.append(IgnoreImportWeightConverter((), f"{self._hf_prefix}.lm_head.weight")) + # else: + # converters.append( + # WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") + # ) + # converters.append( + # WeightConverter(f"layers.{num_layers + 1}.output_weights", f"{self._hf_prefix}.lm_head.weight") + # ) + return converters + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("ssm", "d_inner"),), + export_names=(("ssm_cfg", "d_inner"),), + ), + IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), + ] + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), + ] + + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + if not os.path.exists(directory / "config.json"): + raise FileNotFoundError(f"config.json not found in {directory}") + with open(directory / "config.json") as f: + config = json.load(f) + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + with open(directory / "config.json", "w") as f: + json.dump(config, f) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py new file mode 100644 index 000000000..84242b7db --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py @@ -0,0 +1,30 @@ +from transformers import MistralConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +ssm_config_default = { + "d_state": 64, + "n_v_heads": 32, + "n_qk_heads": 32, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 32 * 128, +} + + +class AprielSSMHybridConfig(MistralConfig): + model_type = "apriel_ssm_thinker_hybrid" + + def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + super().__init__(**kwargs) + self.hybrid_block_layout = hybrid_block_layout + self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 + self.ssm_cfg = ssm_cfg or ssm_config_default + + for k, v in ssm_config_default.items(): + if k not in self.ssm_cfg: + self.ssm_cfg[k] = v # to make sure all elements are present in the config diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py new file mode 100644 index 000000000..03e379e1d --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -0,0 +1,1213 @@ +import copy +from dataclasses import dataclass +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from torch import nn +from transformers import GenerationMixin +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm +from transformers.processing_utils import Unpack +from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging +from transformers.utils.generic import ModelOutput + +from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig + +logger = logging.get_logger(__name__) + + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + +class HybridMambaAttentionStaticCache(Cache): + def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): + super().__init__() # config, batch_size, max_length, device, dtype) + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + self.batch_size = batch_size + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = config.max_position_embeddings if max_length is None else max_length + + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) + + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + new_layer_conv_state = torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + + new_layer_ssm_state = torch.zeros( + batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype + ) + new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) + new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) + else: + # Attention or MLP layer + new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + self.transformer_layers.append(i) + + # if not is_torchdynamo_compiling(): + # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # new_layer_key_cache = getattr(self, f"key_cache_{i}") + # new_layer_value_cache = getattr(self, f"value_cache_{i}") + # torch._dynamo.mark_static_address(new_layer_key_cache) + # torch._dynamo.mark_static_address(new_layer_value_cache) + # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) + # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) + # torch._dynamo.mark_static_address(new_layer_conv_state) + # torch._dynamo.mark_static_address(new_layer_ssm_state) + # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") + # new_layer_conv_state = getattr(self, f"conv_states_{i}") + + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + self.conv_states.append(new_layer_conv_state) + self.ssm_states.append(new_layer_ssm_state) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + cache_position = cache_kwargs.get("cache_position") + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = None) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx is None: + layer_idx = self.transformer_layers[0] + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + self.conv_states += [ + torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + ] + self.ssm_states += [ + torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + +@dataclass +class AprielHybridCausalOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + attention_weights: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# This is from LLmaba/Mohawk: https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + # In __init__, pre-allocate these tensors + self.zeros_buffer = torch.zeros((self.n_v_heads, self.headdim), device=device, dtype=dtype) + self.ones_buffer = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=device, dtype=dtype) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward( + self, + u, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + attention_mask: Optional[torch.Tensor] = None, + return_mixer_matrix=False, + **kwargs, + ): + """ + u: (B, L, D) + Returns: same shape as u + For later refference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bamba/modeling_bamba.py + """ + assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda" + cache_position = kwargs.get("cache_position", None) + batch, seqlen, dim = u.shape + u = apply_mask_to_padding_states(u, attention_mask) + ssm_state, conv_state = None, None + + use_precomputed_states = ( + past_key_value is not None + and past_key_value.has_previous_state + and seqlen == 1 + and past_key_value.conv_states[self.layer_idx].shape[0] + == past_key_value.ssm_states[self.layer_idx].shape[0] + == batch + and cache_position is not None + and cache_position[0] > 0 + ) + if use_precomputed_states: + ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _, _ = self.step(u, ssm_state, conv_state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out + return {"hidden_states": out} + else: + outputs = {} + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if ssm_state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if ssm_state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(ssm_state is not None), + ) + + if ssm_state is not None: + y, ssm_state_update = result + ssm_state.copy_(ssm_state_update) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + # def forward_original( + # self, + # u, + # return_mixer_matrix=False, + # past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + # inference_params=None, + # **kwargs, + # ): + # """ + # u: (B, L, D) + # Returns: same shape as u + # """ + # outputs = {} + # # assert state is None + # batch, seqlen, dim = u.shape + + # ssm_state, conv_state = None, None + # if past_key_value is not None: + # ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + # cache_position = kwargs.get("cache_position", None) + # # if inference_params is not None and inference_params.seqlen_offset > 0: + # if cache_position is not None and cache_position[0] > 0: + # # States are updated inplace + # # TODO: make sure using cache_position is correct here + # u = u.squeeze(1) if len(u.shape) == 3 else u + # out, _, _ = self.step(u, ssm_state, conv_state) + # out = out.unsqueeze(1) if len(u.shape) == 2 else out + # return {"hidden_states": out} + + # # Hacky way to initialize state during inference + # chunk_size = self.chunk_size if ssm_state is None else seqlen + + # # Pad input to nearest multiple of chunklen + # padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + # u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # # Project input + # xBCzA_log = self.in_proj(u) + # xBC, z, A_log = torch.split( + # xBCzA_log, + # [ + # self.d_inner + 2 * self.n_qk_heads * self.d_state, + # self.d_inner, + # self.n_v_heads, + # ], + # dim=-1, + # ) + + # if ssm_state is not None: + # # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + # xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + # conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # # Convolutional layer + # xBC = self.convolutional_forward(xBC, padded_len) + + # x, B, C = torch.split( + # xBC, + # [ + # self.d_inner, + # self.n_qk_heads * self.d_state, + # self.n_qk_heads * self.d_state, + # ], + # dim=-1, + # ) + + # x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + # B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + # C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # # SSM forward + # result = mamba_chunk_scan_combined( + # x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + # dt=A_log, + # dt_softplus=True, + # A=-torch.ones(self.n_v_heads, device=A_log.device), + # B=B, + # C=C, + # chunk_size=chunk_size, + # # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + # return_final_states=(ssm_state is not None), + # ) + + # if ssm_state is not None: + # y, ssm_state_update = result + # ssm_state.copy_(ssm_state_update) + # else: + # y = result + + # Du = torch.einsum("h,blhp->blhp", self.D, x) + # y = rearrange(y + Du, "b l h p -> b l (h p)") + + # # Norm and gate + # out = self.out_proj(y * F.silu(z + self.z_bias)) + # outputs["hidden_states"] = out[:, :seqlen, :] + + # if return_mixer_matrix: + # outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + # return outputs + + def step(self, u, ssm_state, conv_state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state_new = self.convolutional_step(xBC, conv_state) + if conv_state_new is not None: + raise NotImplementedError("Should not end up here snce only support fast path.") + # conv_state.copy_(conv_state_new) # update state in place, only for slow pass + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + ssm_state = ssm_state.to(x.dtype) + zeros = self.zeros_buffer.to(A_log.device).to(x.dtype) # Just cast, don't allocate + ones = self.ones_buffer.to(A_log.device).to(x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=ssm_state, # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, ssm_state, conv_state + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + # if self.layer_idx not in inference_params.ssm_states: + # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + # batch_size, inference_params.max_seqlen, dtype=torch.float32 + # ) + # Get states + ssm_states = inference_params.ssm_states[self.layer_idx] + conv_states = inference_params.conv_states[self.layer_idx] + if initialize_states: + ssm_states.zero_() + conv_states.zero_() + return ssm_states, conv_states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + return xBC, None + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielSSMDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + **kwargs, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + # outputs["hidden_states"] = hidden_states + outputs = (hidden_states,) + + return outputs + + +class AprielHybridIdentity(nn.Module): + def __init__(self, config: AprielSSMHybridConfig): + super().__init__() + self.config = config + + def forward(self, hidden_states: torch.Tensor, **kwargs): + return (hidden_states,) + + +class AprielSSMHybridModel(MistralModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] + Args: + config: AprielSSMHybridConfig + """ + + def __init__(self, config: AprielSSMHybridConfig, **kwargs): + config_copy = copy.deepcopy(config) + config_copy.num_hidden_layers = 0 + super().__init__(config_copy, **kwargs) + self.config = config + blocks = [] + logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") + for layer_idx, type in enumerate(config.hybrid_block_layout): + if type == "m2d": + blocks.append(AprielSSMDecoderLayer(config, layer_idx)) + elif type == "t": + blocks.append(MistralDecoderLayer(config, layer_idx)) + elif type == "i": + blocks.append(AprielHybridIdentity(config)) + else: + raise ValueError(f"Invalid block type: {type}") + self.layers = nn.ModuleList(blocks) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # OO: Cache is initialized in the `prepare_inputs_for_generation` method, so this can be removed + # if use_cache and past_key_values is None: + # past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielHybridPreTrainedModel(PreTrainedModel): + config_class = AprielSSMHybridConfig + base_model_prefix = "model" + _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MistralRMSNorm): + module.weight.data.fill_(1.0) + + +class AprielSSMHybridForCausalLM(AprielHybridPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config: AprielSSMHybridConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = AprielSSMHybridModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + # "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return AprielHybridCausalOutput( + loss=loss, + logits=logits, + all_hidden_states=outputs.hidden_states, + past_key_values=outputs.past_key_values, + ) + + +__all__ = [ + "AprielSSMHybridForCausalLM", + "AprielSSMHybridModel", + "AprielSSMPreTrainedModel", +] diff --git a/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py new file mode 100644 index 000000000..1d230bb67 --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py @@ -0,0 +1,448 @@ +import math +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torch_available, logging + +logger = logging.get_logger(__name__) + +if is_torch_available(): + import torch + + +def _compute_default_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + + +def _compute_yarn_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://arxiv.org/abs/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # No need to keep BC with yarn, unreleased when this new pattern was created. + if len(rope_kwargs) > 0: + raise ValueError( + f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + # Apriel: Use original max_position_embeddings instead of max_position_embeddings + max_position_embeddings = config.rope_scaling.get( + "original_max_position_embeddings", config.max_position_embeddings + ) + factor = config.rope_scaling["factor"] + + # Sets the attention factor as suggested in the paper + attention_factor = config.rope_scaling.get("attention_factor") + if attention_factor is None: + attention_factor = 0.1 * math.log(factor) + 1.0 + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = config.rope_scaling.get("beta_fast") or 32 + beta_slow = config.rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): + """Find dimension range bounds based on rotations""" + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. + pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) + + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, attention_factor + + +def _check_received_keys( + rope_type: str, + received_keys: set, + required_keys: set, + optional_keys: Optional[set] = None, + ignore_keys: Optional[set] = None, +): + """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" + # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present + if "type" in received_keys: + received_keys -= {"type"} + required_keys.add("rope_type") + + # Some models need to store model-specific keys, and we don't want to throw warning at them + if ignore_keys is not None: + received_keys -= ignore_keys + + missing_keys = required_keys - received_keys + if missing_keys: + raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") + + if optional_keys is not None: + unused_keys = received_keys - required_keys - optional_keys + else: + unused_keys = received_keys - required_keys + if unused_keys: + logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") + + +def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + +def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "original_max_position_embeddings"} + optional_keys = {"attention_factor", "beta_fast", "beta_slow"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_scaling.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_scaling.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + logger.warning( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) + + +# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters +# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE +# parameterizations, as long as the callable has the same signature. +ROPE_INIT_FUNCTIONS = { + "default": _compute_default_rope_parameters, + "yarn": _compute_yarn_parameters, +} + +# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. +ROPE_VALIDATION_FUNCTIONS = { + "default": _validate_default_rope_parameters, + "yarn": _validate_yarn_parameters, +} + + +def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): + """ + Validate the RoPE config arguments, given a `PretrainedConfig` object + """ + rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` + if rope_scaling is None: + return + + # BC: "rope_type" was originally "type" + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) + if validation_fn is not None: + validation_fn(config, ignore_keys=ignore_keys) + else: + logger.warning( + f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" + ) + + +class AprielSSMHybridConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AprielModel`]. It is used to instantiate an Apriel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Apriel-5B-Base. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Apriel model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`AprielModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Apriel-5B-Base supports up to 16384 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'yarn', 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_attention_heads + ```python + >>> from transformers import AprielModel, AprielConfig + >>> # Initializing an Apriel Apriel-5B-Base style configuration + >>> configuration = AprielConfig() + >>> # Initializing a model from the Apriel-5B-Base style configuration + >>> model = AprielModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "apriel_ssm_hybrid" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `AprielModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + hybrid_block_layout=["m2d"], + ssm_cfg=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.hybrid_block_layout = hybrid_block_layout + if len(hybrid_block_layout) == 1: + self.hybrid_block_layout = [hybrid_block_layout[0]] * self.num_hidden_layers + assert len(self.hybrid_block_layout) == self.num_hidden_layers + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + ssm_defaults = { + "d_state": 64, + "n_v_heads": 24, + "n_qk_heads": 24, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 24 * self.head_dim, # num_heads * head_dim + } + self.ssm_cfg = ssm_cfg or ssm_defaults + for k, v in ssm_defaults.items(): + if k not in self.ssm_cfg: + self.ssm_cfg[k] = v diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py new file mode 100644 index 000000000..ddb7d0f77 --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -0,0 +1,1568 @@ +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from torch import nn +from transformers import GenerationMixin +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.utils.generic import ModelOutput + +from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import ( + ROPE_INIT_FUNCTIONS, + AprielSSMHybridConfig, +) + +logger = logging.get_logger(__name__) + + +class HybridMambaAttentionStaticCache(Cache): + def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): + super().__init__() # config, batch_size, max_length, device, dtype) + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + self.batch_size = batch_size + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = config.max_position_embeddings if max_length is None else max_length + + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) + + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + new_layer_conv_state = torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + + new_layer_ssm_state = torch.zeros( + batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype + ) + new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) + new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) + else: + # Attention or MLP layer + new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + self.transformer_layers.append(i) + + # if not is_torchdynamo_compiling(): + # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # new_layer_key_cache = getattr(self, f"key_cache_{i}") + # new_layer_value_cache = getattr(self, f"value_cache_{i}") + # torch._dynamo.mark_static_address(new_layer_key_cache) + # torch._dynamo.mark_static_address(new_layer_value_cache) + # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) + # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) + # torch._dynamo.mark_static_address(new_layer_conv_state) + # torch._dynamo.mark_static_address(new_layer_ssm_state) + # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") + # new_layer_conv_state = getattr(self, f"conv_states_{i}") + + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + self.conv_states.append(new_layer_conv_state) + self.ssm_states.append(new_layer_ssm_state) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + cache_position = cache_kwargs.get("cache_position") + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = None) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx is None: + layer_idx = self.transformer_layers[0] + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + self.conv_states += [ + torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + ] + self.ssm_states += [ + torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +@dataclass +class AprielHybridCausalOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + attention_weights: Optional[torch.FloatTensor] = None + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None + + +class AprielRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): + """ + AprielRMSNorm is equivalent to T5LayerNorm + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(AprielRMSNorm) + + +class AprielMLP(nn.Module): + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class AprielRotaryEmbedding(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class AprielAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +# This is from LLmaba/Mohawk: https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward( + self, + u, + return_mixer_matrix=False, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + inference_params=None, + **kwargs, + ): + """ + u: (B, L, D) + Returns: same shape as u + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + ssm_state, conv_state = None, None + if past_key_value is not None: + ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + if inference_params is not None and inference_params.seqlen_offset > 0: + # States are updated inplace + # TODO: make sure inference_params with seqlen_offset are properly initialized + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _, _ = self.step(u, ssm_state, conv_state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if ssm_state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if ssm_state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(ssm_state is not None), + ) + + if ssm_state is not None: + y, ssm_state_update = result + ssm_state.copy_(ssm_state_update) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + def step(self, u, ssm_state, conv_state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state_new = self.convolutional_step(xBC, conv_state) + conv_state.copy_(conv_state_new) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + ssm_state = ssm_state.to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=ssm_state, # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, ssm_state, conv_state + + # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # device = self.in_proj.weight.device + # # conv_state: + # conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + # conv_state = torch.zeros( + # batch_size, + # self.d_conv, + # self.conv1d.weight.shape[0], + # device=device, + # dtype=conv_dtype, + # ).transpose(1, 2) + # # ssm_state: + # ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + # ssm_state = torch.zeros( + # batch_size, + # self.n_v_heads, + # self.headdim, + # self.d_state, + # device=device, + # dtype=ssm_dtype, + # ) + # return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + # if self.layer_idx not in inference_params.ssm_states: + # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + # batch_size, inference_params.max_seqlen, dtype=torch.float32 + # ) + # Get states + ssm_states = inference_params.ssm_states[self.layer_idx] + conv_states = inference_params.conv_states[self.layer_idx] + if initialize_states: + ssm_states.zero_() + conv_states.zero_() + return ssm_states, conv_states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = AprielAttention(config=config, layer_idx=layer_idx) + + self.mlp = AprielMLP(config) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class AprielSSMDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + self.mlp = AprielMLP(config, **factory_kwargs) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + + def forward( + self, hidden_states: torch.Tensor, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + **kwargs, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + # outputs["hidden_states"] = hidden_states + outputs = (hidden_states,) + + return outputs + + # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # """Allocate inference cache for the model.""" + # if getattr(self.mixer, "allocate_inference_cache", None) is None: + # return + # return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +APRIEL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`AprielSSMHybridConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMPreTrainedModel(PreTrainedModel): + config_class = AprielSSMHybridConfig + base_model_prefix = "model" + _no_split_modules = ["AprielDecoderLayer", "AprielSSMDecoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + # def allocate_inference_cache(self, *args, **kwargs): + # """Allocate inference cache for the model.""" + # return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) + + +APRIEL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMHybridModel(AprielSSMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] + Args: + config: AprielSSMHybridConfig + """ + + def __init__(self, config: AprielSSMHybridConfig, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + factory_kwargs = {"device": device, "dtype": dtype} + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) + blocks = [] + logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") + for layer_idx, type in enumerate(config.hybrid_block_layout): + if type == "m2d": + blocks.append(AprielSSMDecoderLayer(config, layer_idx, **factory_kwargs)) + elif type == "t": + blocks.append(AprielDecoderLayer(config, layer_idx)) + else: + raise ValueError(f"Invalid block type: {type}") + self.layers = nn.ModuleList(blocks) + self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.gradient_checkpointing = False + self.rotary_emb = AprielRotaryEmbedding(config=config) + self.has_transformer_layers = any(type == "t" for type in config.hybrid_block_layout) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # def allocate_inference_cache(self, *args, **kwargs): + # """Allocate inference cache for the model.""" + # cache = {} + # for i, layer in enumerate(self.layers): + # if isinstance(layer, AprielSSMDecoderLayer): + # cache[i] = layer.allocate_inference_cache(*args, **kwargs) + # return cache + + @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + inference_params=None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + # past_key_values = HybridMambaAttentionDynamicCache() + logger.warning_once( + "Hybrid Apriel requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) + + if cache_position is None and self.has_transformer_layers: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None and self.has_transformer_layers: + position_ids = cache_position.unsqueeze(0) + + causal_mask = ( + self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions) + if self.has_transformer_layers + else None + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) if self.has_transformer_layers else None + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + inference_params=inference_params, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions and isinstance(decoder_layer, AprielDecoderLayer): + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) or isinstance( + past_key_values, HybridMambaAttentionStaticCache + ) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielSSMHybridForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.model = AprielSSMHybridModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + # "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids=None, + return_hidden_states=False, + return_logits=True, + num_last_tokens=0, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + + # past_key_values is None if prepare_inputs_for_generation is not called, which is the case when we evaluate without calling generate (non-generation tasks) + # Its generally ok if cache is nto instantiated in this case, since we do single pass per sample anyways, a warning will be triggered in the model + outputs: BaseModelOutputWithPast = self.model( + input_ids, + return_hidden_states=return_hidden_states, + position_ids=position_ids, + past_key_values=past_key_values, + **kwargs, + ) + + if outputs["last_hidden_state"] is not None and return_logits: + logits = self.lm_head(outputs["last_hidden_state"]).float() + outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] + else: + outputs["logits"] = None + + return AprielHybridCausalOutput( + loss=None, + logits=outputs["logits"], + all_hidden_states=outputs.hidden_states, + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + ) + + +__all__ = [ + "AprielSSMHybridForCausalLM", + "AprielSSMHybridModel", + "AprielSSMPreTrainedModel", +] diff --git a/fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py new file mode 100644 index 000000000..6943a3124 --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py @@ -0,0 +1,103 @@ +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Apriel SSM model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torch_available, logging + +logger = logging.get_logger(__name__) + +if is_torch_available(): + pass + + +class AprielSSMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AprielModel`]. It is used to instantiate an Apriel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Apriel-5B-Base. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + .... + ```""" + + model_type = "apriel_ssm" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + hidden_act="silu", + initializer_range=0.02, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + mlp_bias=False, + rms_norm_eps=1e-5, + ssm_cfg: dict = None, + head_dim: int = 128, + **kwargs, + ): + self.vocab_size = vocab_size + # self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + # self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + # self.rope_theta = rope_theta + self.mlp_bias = mlp_bias + self.head_dim = head_dim + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + # if self.rope_scaling is not None and "type" in self.rope_scaling: + # self.rope_scaling["rope_type"] = self.rope_scaling["type"] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 24, + "n_qk_heads": 24, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_inner": 24 * self.head_dim, # num_heads * head_dim + } + if self.head_dim != self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"]: + logger.warning("Head dim is not equal to d_inner // n_qk_heads.") + + +__all__ = ["AprielConfig"] diff --git a/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py new file mode 100644 index 000000000..09dc8259c --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py @@ -0,0 +1,743 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from mamba_ssm.utils.generation import GenerationMixin +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.utils.generic import ModelOutput + +from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class CustomMambaCausalLMOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + + +class AprielRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): + """ + AprielRMSNorm is equivalent to T5LayerNorm + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(AprielRMSNorm) + + +class AprielMLP(nn.Module): + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, **factory_kwargs)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, **factory_kwargs)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs): + """ + u: (B, L, D) + Returns: same shape as u + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + state = None + if inference_params is not None: + state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # States are updated inplace + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _ = self.step(u, state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(state is not None), + ) + + if state is not None: + y, ssm_state = result + state["ssm"].copy_(ssm_state) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + def step(self, u, state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u.squeeze(1)) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state = self.convolutional_step(xBC, state["conv"]) + state["conv"].copy_(conv_state) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + state["ssm"] = state["ssm"].to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=state["ssm"], # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.in_proj.weight.device + # conv_state: + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, + self.d_conv, + self.conv1d.weight.shape[0], + device=device, + dtype=conv_dtype, + ).transpose(1, 2) + # ssm_state: + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + ssm_state = torch.zeros( + batch_size, + self.n_v_heads, + self.headdim, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + if self.layer_idx not in inference_params.key_value_memory_dict: + inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + batch_size, inference_params.max_seqlen, dtype=torch.float32 + ) + # Get states + states = inference_params.key_value_memory_dict[self.layer_idx] + if initialize_states: + states["conv"].zero_() + states["ssm"].zero_() + return states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + self.mlp = AprielMLP(config, **factory_kwargs) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + + def forward( + self, hidden_states: torch.Tensor, inference_params=None, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + inference_params=inference_params, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs["hidden_states"] = hidden_states + + return outputs + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for the model.""" + if getattr(self.mixer, "allocate_inference_cache", None) is None: + return + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +APRIEL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`AprielSSMConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMPreTrainedModel(PreTrainedModel): + config_class = AprielSSMConfig + base_model_prefix = "model" + _no_split_modules = ["AprielDecoderLayer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) + + +APRIEL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMModel(AprielSSMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`] + Args: + config: AprielSSMConfig + """ + + def __init__(self, config: AprielSSMConfig, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + factory_kwargs = {"device": device, "dtype": dtype} + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) + self.layers = nn.ModuleList( + [AprielDecoderLayer(config, layer_idx, **factory_kwargs) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return {i: layer.allocate_inference_cache(*args, **kwargs) for i, layer in enumerate(self.layers)} + + @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + return_hidden_states=False, + inference_params=None, + position_ids=None, + ) -> Union[tuple, BaseModelOutputWithPast]: + + hidden_states = self.embed_tokens(input_ids) + + # decoder layers + outputs = { + "last_hidden_state": None, + "all_hidden_states": (hidden_states,) if return_hidden_states else (), + } + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + + layer_outputs = decoder_layer( + hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + # Record outputs + hidden_states = layer_outputs["hidden_states"] + if return_hidden_states: + outputs["all_hidden_states"] += (hidden_states,) + + outputs["last_hidden_state"] = self.norm(hidden_states) + return outputs + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielSSMForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) + self.model = AprielSSMModel(config, device=device, dtype=dtype) + self.vocab_size = config.vocab_size + factory_kwargs = {"device": device, "dtype": dtype} + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids=None, + return_hidden_states=False, + return_logits=True, + inference_params=None, + num_last_tokens=0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + + outputs = self.model( + input_ids, + return_hidden_states=return_hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + + if outputs["last_hidden_state"] is not None and return_logits: + logits = self.lm_head(outputs["last_hidden_state"]).float() + outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] + else: + outputs["logits"] = None + + return CustomMambaCausalLMOutput( + loss=None, + logits=outputs["logits"], + all_hidden_states=outputs["all_hidden_states"], + last_hidden_state=outputs["last_hidden_state"], + ) + + def generate(self, *args, **kwargs): + """ + This is a wrapper to make sure we comply with the HF generation interface for eval harness + """ + return super().generate(*args, **kwargs) + + +__all__ = [ + "AprielSSMForCausalLM", + "AprielModel", + "AprielSSMPreTrainedModel", +] diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py new file mode 100644 index 000000000..de1b03842 --- /dev/null +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -0,0 +1,180 @@ +from typing import Optional, Union + +import lm_eval.models.utils +import torch +from lm_eval.api.registry import register_model +from lm_eval.models.huggingface import HFLM + + +@register_model("apriel_ssm") +class AprielSSMWrapper(HFLM): + """Wrapper for AprielSSM model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig + + self._config = AprielSSMConfig.from_pretrained(pretrained) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.apriel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM + + self._model = AprielSSMForCausalLM.from_pretrained( + pretrained, + device=self._device, + dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + trust_remote_code=True, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + """Generate text from the model.""" + for key in ("do_sample", "attention_mask"): + if key in generation_kwargs: + generation_kwargs.pop(key) + + # The custom GenerationMixin imported from mamba_ssm currently does not support + # passing stopping criteria. + # For the time being, we simply generate to max length, then truncate (equivalent result). + # This should be revisited to speed up generation + # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) + + return self.model.generate( + input_ids=context, + max_length=max_length, + **generation_kwargs, + ) + + +@register_model("apriel_hybrid_ssm") +class AprielHybridSSMWrapper(HFLM): + """Wrapper for AprielHybridSSM model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import AprielSSMHybridConfig + + self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + + self._model = AprielSSMHybridForCausalLM.from_pretrained( + pretrained, + device=self._device, + torch_dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + **kwargs, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self.tokenizer, + stop, + context.shape[1], + context.shape[0], + ) + + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + use_cache=True, + **generation_kwargs, + ) + + +@register_model("apriel_hybrid_ssm_15b") +class AprielHybrid15bSSMWrapper(HFLM): + """Wrapper for AprielHybridSSM model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import ( + AprielSSMHybridConfig, + ) + + self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielSSMHybridForCausalLM, + ) + + self._model = AprielSSMHybridForCausalLM.from_pretrained( + pretrained, + device=self._device, + torch_dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + **kwargs, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self.tokenizer, + stop, + context.shape[1], + context.shape[0], + ) + + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + use_cache=True, + **generation_kwargs, + ) diff --git a/fast_llm/models/ssm/external/eval/run_evalchemy.py b/fast_llm/models/ssm/external/eval/run_evalchemy.py new file mode 100644 index 000000000..1cbb5b4da --- /dev/null +++ b/fast_llm/models/ssm/external/eval/run_evalchemy.py @@ -0,0 +1,9 @@ +from eval.eval import cli_evaluate +from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import ( # noqa: F401 + AprielHybrid15bSSMWrapper, + AprielHybridSSMWrapper, + AprielSSMWrapper, +) + +if __name__ == "__main__": + cli_evaluate() diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm/models/ssm/external/eval/run_lm_eval.py new file mode 100644 index 000000000..53c0febab --- /dev/null +++ b/fast_llm/models/ssm/external/eval/run_lm_eval.py @@ -0,0 +1,10 @@ +from lm_eval.__main__ import cli_evaluate + +from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import ( # noqa: F401 + AprielHybrid15bSSMWrapper, + AprielHybridSSMWrapper, + AprielSSMWrapper, +) + +if __name__ == "__main__": + cli_evaluate() diff --git a/fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py new file mode 100644 index 000000000..b8173b733 --- /dev/null +++ b/fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py @@ -0,0 +1,94 @@ +from enum import Enum + +from transformers.configuration_utils import PretrainedConfig + + +class StateUpdateKernel(Enum): + ssu_verification = "ssu_verification" # selective scan for multi-token verification, not implemented yet + cs = "chunk_scan" # see https://proceedings.mlr.press/v262/wu24a.html + ssu = "standard" # usual one token per time-step inference using selective-scan update, no verification + + +class MTPLlambaConfig(PretrainedConfig): + r"""Configuration class for the CustomMamba model. + + This configuration is used to instantiate the CustomMamba model according to the specified arguments, + defining the model architecture. + + Args: + vocab_size (`int`, *optional*, defaults to 128256): + Vocabulary size of the model. + tie_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + pad_vocab_size_multiple (`int`, *optional*, defaults to 8): + Pad the vocabulary size up to the next multiple of this value. + lm_head_bias (`bool`, *optional*, defaults to `False`): + Whether the LM head includes a bias term. + d_model (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + lm_head_prenorm (`str`, *optional*, defaults to "rms"): + Normalization type for LM head. + n_layer (`int`, *optional*, defaults to 32): + Number of layers in the model. + resid_dropout (`float`, *optional*, defaults to 0.0): + Dropout rate for residual connections. + norm_epsilon (`float`, *optional*, defaults to 1e-5): + Epsilon value used for normalization layers. + mlp_cfg (`dict`, *optional*): + Configuration for the MLP (Multi-Layer Perceptron) layer, including intermediate size, activation function, and whether to use bias. + ssm_cfg (`dict`, *optional*): + Configuration for the SSM (State Space Model) layer, including d_state, number of heads, expansion, and other parameters. + + """ + + model_type = "llamba" + + def __init__( + self, + vocab_size: int, + d_model: int, + tie_embeddings: bool = False, + pad_vocab_size_multiple: int = 8, + lm_head_bias: bool = False, + n_layer: int = 32, + resid_dropout: float = 0.0, + norm_epsilon: float = 1e-5, + mlp_cfg: dict = None, + ssm_cfg: dict = None, + prediction_heads=1, + state_update_kernel: StateUpdateKernel = StateUpdateKernel.cs, + **kwargs, + ): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.tie_embeddings = tie_embeddings + self.pad_vocab_size_multiple = pad_vocab_size_multiple + self.lm_head_bias = lm_head_bias + self.d_model = d_model + self.n_layer = n_layer + self.resid_dropout = resid_dropout + self.norm_epsilon = norm_epsilon + self.prediction_heads = prediction_heads + assert ( + state_update_kernel != StateUpdateKernel.ssu_verification + ), "Only chunk scan and standard modes are supported for now" + self.state_update_kernel = state_update_kernel + + # MLP (Multi-Layer Perceptron) Config + self.mlp_cfg = mlp_cfg or { + "intermediate_size": 14336, + "bias": False, + "act_fn": "silu", + } + + # SSM (State Space Model) Config + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 32, + "n_qk_heads": 32, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + } diff --git a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py new file mode 100644 index 000000000..6d9746db1 --- /dev/null +++ b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py @@ -0,0 +1,389 @@ +# Copyright (c) 2024, Kevin Li, Aviv Bick. + +import json +import os +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin +from mamba_ssm.utils.generation import GenerationMixin +from torch import Tensor, nn +from transformers.activations import ACT2FN +from transformers.utils.generic import ModelOutput + +from .configuration_mtp_llamba import MTPLlambaConfig as LlambaConfig +from .discrete_mamba2 import DiscreteMamba2 + + +class LlamaRMSNorm(nn.Module): + """LlamaRMSNorm (taken from transformers.models.llama.modeling_llama.LlamaRMSNorm).""" + + def __init__(self, hidden_size, eps=1e-6, factory_kwargs=None): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + """ + Args: + hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size). + + Returns: + torch.Tensor of shape (batch_size, seq_len, hidden_size). + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + """Set the extra representation of the module.""" + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class LlamaMLP(nn.Module): + """LlamaMLP (taken from transformers.models.llama.modeling_llama.LlamaMLP).""" + + def __init__(self, hidden_size, intermediate_size, bias, act_fn, factory_kwargs=None): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias, **factory_kwargs) + self.act_fn = ACT2FN[act_fn] + + def forward(self, x): + """ + Args: + x: torch.Tensor of shape (batch_size, seq_len, hidden_size). + + Returns: + torch.Tensor of shape (batch_size, seq_len, hidden_size). + """ + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +@dataclass +class CustomMambaCausalLMOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + + +class MTPLlambaLMHeadModel(nn.Module, GenerationMixin, PyTorchModelHubMixin): + """MambaLM model with a language modeling head on top (linear layer).""" + + def __init__(self, config, initializer_cfg=None, device=None, dtype=None, **kwargs) -> None: + super().__init__() + + # Load config + if not isinstance(config, LlambaConfig): + config = LlambaConfig(**config) + self.config = config + + # Factory kwargs + factory_kwargs = {"device": device, "dtype": dtype} + + # Pad vocab size to be a multiple of pad_vocab_size_multiple + vocab_size = config.vocab_size + pad_vocab_size_multiple = config.pad_vocab_size_multiple + if vocab_size % pad_vocab_size_multiple != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + self.config.vocab_size = vocab_size + + # Mixer model + self.backbone = MixerModel( + input_size=vocab_size, + config=self.config, + initializer_cfg=initializer_cfg, + **factory_kwargs, + ) + + # MTP heads + self.mtp_heads = nn.ModuleList( + [ + Block( + config=config, + factory_kwargs=factory_kwargs, + layer_idx=layer_idx, + ).to(device) + for layer_idx in range(config.n_layer, config.n_layer + config.prediction_heads - 1) + ] + ) + + self.mtp_norms = nn.ModuleList( + [ + LlamaRMSNorm(config.d_model, eps=config.norm_epsilon, factory_kwargs=factory_kwargs) + for _ in range(config.prediction_heads - 1) + ] + ) + # LM head + if not self.config.tie_embeddings: + self.lm_head = nn.Linear( + in_features=self.config.d_model, + out_features=self.config.vocab_size, + bias=self.config.lm_head_bias, + **factory_kwargs, + ) + else: + self.lm_head = lambda x: x @ self.backbone.embedding.weight.t() + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + + mtps = { + i + self.config.n_layer: layer.allocate_inference_cache(*args, **kwargs) + for i, layer in enumerate(self.mtp_heads) + } + return {**self.backbone.allocate_inference_cache(*args, **kwargs), **mtps} + + def forward( + self, + input_ids, + position_ids=None, + return_hidden_states=False, + return_logits=True, + inference_params=None, + num_last_tokens=0, + ): + """ + Args: + input_ids: torch.Tensor of shape (batch_size, seq_len), + position_ids: torch.Tensor of shape (batch_size, seq_len), optional, not used (just for compatibility), + return_hidden_states: bool, optional, + return_logits: bool, optional, whether to compute the logits with the LM head, + inference_params: dict, optional, the model's inference cache, + num_last_tokens: int, optional. If > 0, only return the logits for the last n tokens. + + Returns: + CustomMambaCausalLMOutput. + + """ + outputs = self.backbone( + input_ids, + return_hidden_states=return_hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + + # MTP heads processing + latents = [] + hidden_states = outputs["last_hidden_state"] + hidden_states_before_last = outputs["hidden_state_before_last"] + + # last layer already has layer norm applied + latents.append(hidden_states) + + # Process through MTP heads + for i, mtp_head in enumerate(self.mtp_heads): + mtp_outputs = mtp_head( + hidden_states_before_last, + inference_params=inference_params, + position_ids=position_ids, + ) + mtp_hidden_states = mtp_outputs["hidden_states"] + latents.append(self.mtp_norms[i](mtp_hidden_states)) + + # Stack the latents to get (batch_size, seq_len, num_prediction_heads, hidden_size) + stacked_latents = torch.stack(latents, dim=-2) + + if return_logits: + if isinstance(self.lm_head, nn.Linear): + # Apply lm_head to each prediction head's output + logits = self.lm_head(stacked_latents).float() + else: + # Using the tied embedding weights + logits = self.lm_head(stacked_latents) + + outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] + else: + outputs["logits"] = None + + return CustomMambaCausalLMOutput( + loss=None, + logits=outputs["logits"], + all_hidden_states=outputs["all_hidden_states"], + last_hidden_state=stacked_latents, + ) + + def save_pretrained(self, save_directory): + """ + Minimal implementation of save_pretrained for MambaLMHeadModel. + Save the model and its configuration file to a directory. + """ + # Ensure save_directory exists + if not os.path.exists(save_directory): + os.makedirs(save_directory) + + # Save the model's state_dict + model_path = os.path.join(save_directory, "pytorch_model.bin") + torch.save(self.state_dict(), model_path) + + # Save the configuration of the model + config_path = os.path.join(save_directory, "config.json") + with open(config_path, "w") as f: + json.dump(self.config.to_dict(), f) + + +class MixerModel(nn.Module): + """Mixer model with a stack of Mixer layers.""" + + def __init__(self, input_size, config=None, device=None, dtype=None, **kwargs) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.config = config + self.embedding = nn.Embedding(input_size, self.config.d_model, **factory_kwargs) + + self.layers = nn.ModuleList( + [ + Block( + config=config, + factory_kwargs=factory_kwargs, + layer_idx=i, + ).to(device) + for i in range(self.config.n_layer) + ] + ) + + self.final_layernorm = LlamaRMSNorm( + hidden_size=self.config.d_model, + eps=self.config.norm_epsilon, + factory_kwargs=factory_kwargs, + ) + + return + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return {i: layer.allocate_inference_cache(*args, **kwargs) for i, layer in enumerate(self.layers)} + + def forward( + self, + input_ids, + return_hidden_states=False, + inference_params=None, + position_ids=None, + ): + """Run the model.""" + # Start running the layers + hidden_states = self.embedding(input_ids) + + # Initialize outputs + outputs = { + "last_hidden_state": None, + "hidden_state_before_last": None, + "all_hidden_states": (hidden_states,) if return_hidden_states else (), + } + + # Run the layers + for layer in self.layers: + layer_outputs = layer( + hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + if layer == self.layers[-1]: + outputs["hidden_state_before_last"] = hidden_states + # Record outputs + hidden_states = layer_outputs["hidden_states"] + if return_hidden_states: + outputs["all_hidden_states"] += (hidden_states,) + + # Last layer, apply layer norm + outputs["last_hidden_state"] = self.final_layernorm(hidden_states) + return outputs + + +class Block(nn.Module): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection. + + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + + def __init__(self, config, factory_kwargs, layer_idx, **kwargs): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + # Mixer + self.mixer = DiscreteMamba2( + d_model=self.config.d_model, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + # Other components + self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) + self.post_attention_layernorm = LlamaRMSNorm( + hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs + ) + self.mlp = LlamaMLP( + hidden_size=self.config.d_model, + **config.mlp_cfg, + factory_kwargs=factory_kwargs, + ) + + def forward( + self, + hidden_states: Tensor, + inference_params=None, + **kwargs, + ): + """ + Pass the input through the encoder layer. + + Args: + hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), + inference_params: dict, optional, + + Returns: + dict with keys: + hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), + mamba_hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), + transfer_matrix: torch.Tensor of shape (batch_size, seq_len, seq_len). + """ + outputs = {} + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Apply Mixer + mixer_outputs = self.mixer( + hidden_states, + inference_params=inference_params, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs["hidden_states"] = hidden_states + + return outputs + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for the model.""" + if getattr(self.mixer, "allocate_inference_cache", None) is None: + return + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 6ff6c5f52..118a195b8 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -11,7 +11,7 @@ from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.models.gpt.model import GPTBaseModel -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig +from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType logger = logging.getLogger(__name__) @@ -44,7 +44,7 @@ def get_output_layers(self) -> list[Layer]: if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): - if block_type == "t": + if block_type == SSMBlockType.transformer.value: layers.append( TransformerLayer( self._config.transformer, @@ -53,7 +53,7 @@ def get_output_layers(self) -> list[Layer]: return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == "m2": + elif block_type == SSMBlockType.mamba2_discrete.value: mamba_block = self.SSM_BLOCK_CLS( config_transformer=self._config.transformer, config_ssm=self._config.ssm, @@ -63,7 +63,7 @@ def get_output_layers(self) -> list[Layer]: return_input=i != self._config.prediction_heads - 1, ) layers.append(mamba_block) - elif block_type == "m": + elif block_type == SSMBlockType.mamba.value: mamba_block = self.SSM_BLOCK_CLS( config_transformer=self._config.transformer, config_ssm=self._config.ssm, @@ -74,7 +74,7 @@ def get_output_layers(self) -> list[Layer]: ) layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be 't' or 'm' or 'm2'") + raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -88,7 +88,7 @@ def get_layers(self) -> list[Layer]: # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): - if block_type == "t": + if block_type == SSMBlockType.transformer.value: # Transformer block layers.append( TransformerLayer( @@ -100,7 +100,7 @@ def get_layers(self) -> list[Layer]: ), ) ) - elif block_type == "m2": + elif block_type == SSMBlockType.mamba2_discrete.value: mamba_block = self.SSM_BLOCK_CLS( config_transformer=self._config.transformer, config_ssm=self._config.ssm, @@ -112,7 +112,8 @@ def get_layers(self) -> list[Layer]: ), ) layers.append(mamba_block) - elif block_type == "m": + + elif block_type == SSMBlockType.mamba.value: # Create Mamba block mamba_block = self.SSM_BLOCK_CLS( config_transformer=self._config.transformer, @@ -126,7 +127,7 @@ def get_layers(self) -> list[Layer]: ) layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be 't' or 'm' or 'm2'") + raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") # Add the output layers layers += self.get_output_layers() diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 61ac1014c..bbee3a986 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -339,6 +339,24 @@ def check_equal_nested(config_a, config_b): raise ValueError("\n".join(errors)) +def get_lr_scale( + lr_scale: float | None | tuple[float | None, ...], layer_lr_scale: float | None +) -> float | None | tuple[float | None, ...]: + """ + Combine module and layer lr_scale. + If one is None, return the other. + """ + if lr_scale is None: + return layer_lr_scale + if layer_lr_scale is None: + return lr_scale + if isinstance(lr_scale, float): + return lr_scale * layer_lr_scale + if isinstance(lr_scale, tuple): + return tuple(lrs * layer_lr_scale if lrs is not None else layer_lr_scale for lrs in lr_scale) + raise ValueError(f"Invalid lr_scale: {lr_scale} (type {type(lr_scale)})") + + class Interrupter: def __init__(self, enabled: bool = True, signals: typing.Sequence[int] = (signal.SIGINT, signal.SIGTERM)): self._enabled = enabled diff --git a/setup.cfg b/setup.cfg index 381225bf8..1cde57d16 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,7 +24,7 @@ CORE = safetensors>=0.4.4 # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation flash-attn==2.7.2.post1 - mamba_ssm[causal-conv1d]==2.2.4 + mamba_ssm==2.2.4 # Required for some optional features and tools. @@ -42,6 +42,8 @@ OPTIONAL = # Miscellaneous requests>=2.32.3 tqdm>=4.66.3 + # For causal_conv1d + causal_conv1d>=1.4.0 DEV = # Pre-commit git hook diff --git a/tests/test_mtp.py b/tests/test_mtp.py index edce4e74d..71c55e0fc 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -9,6 +9,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead +from fast_llm.layers.ssm.config import SSMBlockType from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.models.gpt.config import GPTBaseModelConfig @@ -20,7 +21,7 @@ from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.models.ssm.model import HybridSSMBaseModel -except ImportError: +except Exception: MambaLayer, HybridSSMBaseModel, DiscreteMamba2 = ( None, None, @@ -135,11 +136,11 @@ def test_transformer_mtp(config_dict: dict[str, typing.Any]): @pytest.mark.parametrize( ("hybrid_block_layout", "prediction_heads", "default_mtp_type"), [ - (["m", "t"], 1, None), - (["t", "m"], 2, None), - (["m", "t"], 2, None), - (["t", "m2"], 3, None), - (["t", "m2"], 3, "m"), + ([SSMBlockType.mamba.value, SSMBlockType.transformer.value], 1, None), + ([SSMBlockType.transformer.value, SSMBlockType.mamba.value], 2, None), + ([SSMBlockType.mamba.value, SSMBlockType.transformer.value], 2, None), + ([SSMBlockType.transformer.value, SSMBlockType.mamba2_discrete.value], 3, None), + ([SSMBlockType.transformer.value, SSMBlockType.mamba2_discrete.value], 3, SSMBlockType.mamba.value), ], ) def test_hybrid_model_mtp(distributed_config, hybrid_block_layout, prediction_heads, default_mtp_type): @@ -154,7 +155,11 @@ def test_hybrid_model_mtp(distributed_config, hybrid_block_layout, prediction_he model.to("cuda") num_heads, num_mtp_blocks = 0, 0 - str_block_mapping = {"t": TransformerLayer, "m": MambaLayer, "m2": DiscreteMamba2} + str_block_mapping = { + SSMBlockType.transformer: TransformerLayer, + SSMBlockType.mamba: MambaLayer, + SSMBlockType.mamba2_discrete: DiscreteMamba2, + } mtp_block_type = default_mtp_type or hybrid_block_layout[-1] for block in model.get_output_layers(): if isinstance(block, LanguageModelHead): diff --git a/tests/test_ssms.py b/tests/test_ssms.py index a6922a454..f3eb92617 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -13,9 +13,10 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames +from fast_llm.layers.ssm.config import SSMBlockType from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import AprielSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat from tests.common import get_hybrid_config, materialize_meta_tensors try: @@ -23,14 +24,13 @@ from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel -except ImportError: +except Exception: MambaLayer, LlambaBlock, HybridSSMBaseModel, DiscreteMamba2 = ( None, None, None, None, ) - # Mamba not installed, skipping tests try: from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel @@ -139,13 +139,58 @@ def test_load_from_llamba_checkpoint(distributed_config): assert torch.allclose(logits, hf_logits, atol=1e-2) +def get_hf_apriel_hybrid_out(input_ids, path, format): + from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + + model = AprielSSMHybridForCausalLM.from_pretrained(path, strict=True).to("cuda") + parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) + print(f"Parameter sum: {parameter_sum}") + output = model(input_ids) + del model + torch.cuda.empty_cache() + return output, parameter_sum + + +@pytest.mark.slow +@pytest.mark.skipif( + not run_test + and not pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug").exists(), + reason=f"Skipping because no CUDA available or Mamba not installed", +) +def test_load_from_hybridssm_checkpoint(distributed_config): + """ + Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. + """ + vocab_size = 131072 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json + batch_size = 2 + seq_length = 32 + + path = pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug") + format = AprielSSMHHybridHuggingfaceCheckpointFormat + + x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") + hf_logits, parameter_sum_hf = get_hf_apriel_hybrid_out(x, path, format) + hf_logits = hf_logits["logits"].cpu() + + # Create checkpoint load config + checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) + # Initialize model + model = HybridSSMModel.from_pretrained(checkpoint_config) + param_sum = 0 + for stage in model.stages: + for fsdp in stage.fsdps: + if hasattr(fsdp, "_weight_shard"): + param_sum += torch.sum(fsdp._weight_shard).item() + assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 + + @pytest.mark.extra_slow @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( "hybrid_block_layout,LAYER_CLS", [ - (["m", "t"], MambaLayer), - (["m2", "t"], DiscreteMamba2), + ([SSMBlockType.mamba, SSMBlockType.transformer], MambaLayer), + ([SSMBlockType.mamba2_discrete, SSMBlockType.transformer], DiscreteMamba2), ], ids=["mamba", "discrete_mamba2"], ) @@ -214,7 +259,7 @@ def test_mamba_block(distributed_config, distributed): ("hybrid_block_layout"), [ (["m", "t"]), - (["m2", "t"]), + (["m2d", "t"]), ], ids=["mamba", "discrete_mamba2"], ) @@ -301,3 +346,6 @@ def test_hybrid_model_train_with_fast_mode(distributed_config, hybrid_block_layo # }, # losses=losses, # ) + +if __name__ == "__main__": + pytest.main(["-s", __file__])