diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index f233d2f44..49dacb914 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -29,6 +29,7 @@ def __init__( config: SSMConfig, layer_idx: int, tensor_space: TensorSpace, + return_input: bool = False, ): """ See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. @@ -42,6 +43,7 @@ def __init__( self.config: SSMConfig = config bias = config.add_bias_linear self.layer_idx = layer_idx + self._return_input = return_input td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) @@ -122,10 +124,10 @@ def forward(self, hidden_states, kwargs): outputs["hidden_states"]: (B, L, D). outputs["state"]: inference cache. """ - u = hidden_states + input_ = hidden_states outputs = {} # assert state is None - batch, seqlen, dim = u.shape + batch, seqlen, dim = input_.shape state = None @@ -134,7 +136,7 @@ def forward(self, hidden_states, kwargs): # Pad input to nearest multiple of chunklen padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = torch.nn.functional.pad(u, (0, 0, 0, padded_len - seqlen)) + u = torch.nn.functional.pad(input_, (0, 0, 0, padded_len - seqlen)) # Project input xBCzA_log = self.in_proj(u) @@ -198,10 +200,13 @@ def forward(self, hidden_states, kwargs): # Norm and gate out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :] + outputs["hidden_states"] = out[:, :seqlen, :].contiguous() + + if self._return_input: + return torch.stack([input_, outputs["hidden_states"]], dim=0), None # TODO: since we do not support inference for now, we only return the hidden states for now. 
- return outputs["hidden_states"].contiguous(), None + return outputs["hidden_states"], None def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 1695cf2fb..4704b5228 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -57,6 +57,7 @@ def __init__( config: SSMConfig, layer_idx: int, tensor_space: TensorSpace, + return_input: bool = False, ): factory_kwargs = {} super().__init__() @@ -139,6 +140,7 @@ def __init__( **factory_kwargs, ) self.out_proj.weight.auto_grad_accumulation = True + self._return_input = return_input def forward(self, hidden_states, kwargs): batch, seqlen, dim = hidden_states.shape @@ -170,4 +172,6 @@ def forward(self, hidden_states, kwargs): delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) + if self._return_input: + out = torch.stack((hidden_states, out), dim=0) return out, None diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index d04767b98..0311cc695 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -32,9 +32,14 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): ) hybrid_block_layout: list[str] = Field( default_factory=lambda: ["m2"], - desc="Pattern of blocks to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for Descrete Mamba2.", + desc="Pattern of blocks to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for Discrete Mamba2", hint=FieldHint.architecture, ) + default_mtp_type: str | None = Field( + default=None, + desc="Multi-token prediction mixer to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for discrete Mamba2. 
If None, will use the last block type in `hybrid_block_layout`.", + hint=FieldHint.optional, + ) use_megatron_initialization: bool = Field( default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing ) # TODO: is this needed? @@ -66,7 +71,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - if "m2" in self.hybrid_block_layout: + if "m2" in self.hybrid_block_layout or self.default_mtp_type == "m2": # Mamba2 specific dimensions # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 headdim = d_inner // self.ssm.n_v_heads @@ -84,23 +89,27 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) def _validate(self): - if len(self.hybrid_block_layout) != self.transformer.num_layers: - len_block_layout = len(self.hybrid_block_layout) + len_block_layout = len(self.hybrid_block_layout) + if len_block_layout != self.transformer.num_layers: if self.transformer.num_layers % len_block_layout != 0: raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" + f"hybrid_block_layout length {len_block_layout} does not match num_layers {self.transformer.num_layers}" ) num_repeats = int(self.transformer.num_layers // len_block_layout) logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" + f"hybrid_block_layout length {len_block_layout} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} 
{num_repeats} times" ) self.hybrid_block_layout = self.hybrid_block_layout * num_repeats Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) Assert.custom( lambda _: all(block_type in ["t", "m", "m2"] for block_type in self.hybrid_block_layout), f"Invalid block type: {self.hybrid_block_layout}. Must be 't' or 'm' or 'm2'", ) + Assert.custom( + lambda _: self.default_mtp_type in ["t", "m", "m2", None], + f"Invalid MTP type: {self.default_mtp_type}. Must be 't' or 'm' or 'm2' or None", + ) super()._validate() diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 33d2c185c..6ff6c5f52 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -20,7 +20,7 @@ class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[Conf """ A hybrid model that interleaves Transformer and Mamba blocks. Right now only LlambaBlock is supported. - AS for the mixer, transformer uses MHA. For the LLlambaBlock we support Mamba1 and descrete mamba2. + As for the mixer, transformer uses MHA. For the LlambaBlock we support Mamba1 and discrete mamba2. """ config_class: typing.ClassVar[type[HybridSSMBaseModelConfig]] = HybridSSMBaseModelConfig @@ -34,6 +34,51 @@ def __init__( self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) + def get_output_layers(self) -> list[Layer]: + """ + Get the output layers of the model. + This includes the language model head and any additional heads specified in the configuration. 
+ """ + layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + + if self._config.prediction_heads > 1: + block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] + for i in range(1, self._config.prediction_heads): + if block_type == "t": + layers.append( + TransformerLayer( + self._config.transformer, + self._tensor_space, + layer_index=len(self._config.hybrid_block_layout), + return_input=i != self._config.prediction_heads - 1, + ) + ) + elif block_type == "m2": + mamba_block = self.SSM_BLOCK_CLS( + config_transformer=self._config.transformer, + config_ssm=self._config.ssm, + mixer_cls=DiscreteMamba2, + layer_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + layers.append(mamba_block) + elif block_type == "m": + mamba_block = self.SSM_BLOCK_CLS( + config_transformer=self._config.transformer, + config_ssm=self._config.ssm, + mixer_cls=MambaLayer, + layer_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + layers.append(mamba_block) + else: + raise ValueError(f"Invalid block type: {block_type}. 
Must be 't' or 'm' or 'm2'") + layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) + + return layers + def get_layers(self) -> list[Layer]: """ Create a list of layers for the model, interleaving Transformer and Mamba blocks @@ -50,6 +95,9 @@ def get_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, layer_index=i + 1, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), ) ) elif block_type == "m2": @@ -59,9 +107,11 @@ def get_layers(self) -> list[Layer]: mixer_cls=DiscreteMamba2, layer_index=i + 1, tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), ) layers.append(mamba_block) - elif block_type == "m": # Create Mamba block mamba_block = self.SSM_BLOCK_CLS( @@ -70,14 +120,16 @@ def get_layers(self) -> list[Layer]: mixer_cls=MambaLayer, layer_index=i + 1, tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), ) layers.append(mamba_block) - else: raise ValueError(f"Invalid block type: {block_type}. 
Must be 't' or 'm' or 'm2'") - # Add the language model head - layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)) + # Add the output layers + layers += self.get_output_layers() return layers diff --git a/tests/common.py b/tests/common.py index 5bd9563fa..569d690cc 100644 --- a/tests/common.py +++ b/tests/common.py @@ -13,6 +13,8 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.models.gpt.config import ( LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, @@ -21,7 +23,7 @@ Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat from fast_llm.tools.train import CliTrainingConfig from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs @@ -419,3 +421,43 @@ def run_test_script( TEST_RESULTS_PATH / name / ARTIFACT_PATH, config, ) + + +def materialize_meta_tensors(model, tensor_space): + # Materialize parameters that are on meta device + for name, param in model.named_parameters(): + if param.device.type == "meta": + # Check if the parameter is a custom tensor type + if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): + param_data = param.new_empty(param.shape, device="cuda") + # Initialize param_data + param.init_parameter(param_data, tensor_space.distributed) + # Replace the parameter in the module + module_path, param_name = name.rsplit(".", 1) if "." 
in name else (None, name) + module = model + if module_path is not None: + for part in module_path.split("."): + module = getattr(module, part) + param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation + param.grad = None + param.grad_buffer = torch.empty_like(param) + param.param_grad_is_zero = True + module._parameters[param_name] = param + return model + + +def get_hybrid_config(hybrid_block_layout=["t", "m"], prediction_heads=1, default_mtp_type=None): + config = HybridSSMBaseModelConfig( + transformer=TransformerConfig(num_layers=len(hybrid_block_layout)), + ssm=SSMConfig(), + hybrid_block_layout=hybrid_block_layout, + prediction_heads=prediction_heads, + default_mtp_type=default_mtp_type, + init_method_std_embed=0.02, + init_method_min_embed=-0.02, + init_method_max_embed=0.02, + use_position_embeddings=True, + tie_word_embeddings=False, + ) + return config diff --git a/tests/test_mtp.py b/tests/test_mtp.py new file mode 100644 index 000000000..918144fd0 --- /dev/null +++ b/tests/test_mtp.py @@ -0,0 +1,203 @@ +import typing + +import pytest +import torch + +from fast_llm.config import UpdateType +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames +from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT +from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.models.gpt.config import GPTBaseModelConfig +from fast_llm.models.gpt.model import GPTBaseModel +from fast_llm.utils import Assert +from tests.common import get_hybrid_config, materialize_meta_tensors, requires_cuda + +try: + from 
fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + from fast_llm.layers.ssm.mamba_layer import MambaLayer + from fast_llm.models.ssm.model import HybridSSMBaseModel +except ImportError: + MambaLayer, HybridSSMBaseModel, DiscreteMamba2 = ( + None, + None, + None, + ) + # Mamba not installed, skipping tests + + +run_hybrid_test = MambaLayer is not None and DiscreteMamba2 is not None and torch.cuda.is_available() + + +SEQUENCE_LENGTH = 200 +BATCH_SIZE = 4 +HIDDEN_SIZE = 256 +VOCAB_SIZE = 500 + + +@pytest.fixture +def distributed_config(): + return DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + sequence_data_parallel=1, + local_world_size=1, + world_size=1, + ) + + +@pytest.fixture +def distributed(distributed_config): + return Distributed(config=distributed_config) + + +@requires_cuda +@pytest.mark.parametrize( + "config_dict", + ( + {"prediction_heads": 1}, + {"prediction_heads": 2, "tie_word_embeddings": False}, + {"prediction_heads": 5, "tie_word_embeddings": False}, + ), +) +def test_transformer_mtp(config_dict: dict[str, typing.Any]): + config = GPTBaseModelConfig.from_dict( + { + "transformer": { + "hidden_size": HIDDEN_SIZE, + "num_layers": 2, + }, + "vocab_size": VOCAB_SIZE, + }, + config_dict, + update_type=UpdateType.update, + ) + distributed_config = DistributedConfig.from_dict({}) + distributed = Distributed(distributed_config) + model = GPTBaseModel(config, distributed_config) + model.setup(distributed) + materialize_meta_tensors(model, model._tensor_space) + model.to("cuda") + + sequence_first = config.sequence_first or ( + config.cross_entropy_splits is not None and config.cross_entropy_splits > 1 + ) + target = torch.randint( + 0, + VOCAB_SIZE, + ( + (SEQUENCE_LENGTH + config.prediction_heads - 1, BATCH_SIZE) + if sequence_first + else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) + ), + dtype=torch.int64, + device=distributed.device, + ) + input_ = torch.randint( + 0, + VOCAB_SIZE, + (SEQUENCE_LENGTH, 
BATCH_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH), + device=distributed.device, + ) + attention_mask = torch.ones((1, 1, 1, 1), device="cuda", dtype=torch.bool) + position_ids = torch.arange(SEQUENCE_LENGTH, device="cuda", dtype=torch.int64) + kwargs = { + "position_ids": position_ids, + TransformerKwargs.sequence_first: sequence_first, + TransformerKwargs.attention_mask: attention_mask, + TransformerKwargs.attention_mask_value: -100, + TransformerKwargs.grad_output: 1.0, + LanguageModelKwargs.labels: target, + } + if config.tie_word_embeddings: + kwargs[WORD_EMBEDDINGS_WEIGHT] = model.embedding.word_embeddings_weight + else: + kwargs[OUTPUT_WEIGHTS] = model.model_head.output_weights + losses = {LanguageModelLossNames.multi_token_prediction_loss(i): [] for i in range(model._config.prediction_heads)} + _ = model(input_, kwargs, losses=losses) + for loss_name, loss_values in losses.items(): + Assert.gt(len(loss_values), 0) + loss = sum( + [ + sum(losses[LanguageModelLossNames.multi_token_prediction_loss(i)]) + for i in range(model._config.prediction_heads) + ] + ) + loss.backward() + + +@requires_cuda +@pytest.mark.skipif(not run_hybrid_test, reason="No CUDA available or Mamba not installed") +@pytest.mark.parametrize( + ("hybrid_block_layout", "prediction_heads", "default_mtp_type"), + [ + (["m", "t"], 1, None), + (["t", "m"], 2, None), + (["m", "t"], 2, None), + (["t", "m2"], 3, None), + (["t", "m2"], 3, "m"), + ], +) +def test_hybrid_model_mtp(distributed_config, hybrid_block_layout, prediction_heads, default_mtp_type): + hybrid_config = get_hybrid_config( + hybrid_block_layout=hybrid_block_layout, prediction_heads=prediction_heads, default_mtp_type=default_mtp_type + ) + model = HybridSSMBaseModel(hybrid_config, distributed_config) + distributed = Distributed(distributed_config) + model.setup(distributed) + tensor_space = model._tensor_space + materialize_meta_tensors(model, tensor_space) + model.to("cuda") + + num_heads, num_mtp_blocks = 0, 0 + 
str_block_mapping = {"t": TransformerLayer, "m": MambaLayer, "m2": DiscreteMamba2} + mtp_block_type = default_mtp_type or hybrid_block_layout[-1] + for block in model.get_output_layers(): + if isinstance(block, LanguageModelHead): + num_heads += 1 + else: + block = getattr(block, "mixer", block) + Assert.custom( + lambda _: isinstance(block, str_block_mapping[mtp_block_type]), + f"Block {block} is not of type {str_block_mapping[mtp_block_type]}", + ) + num_mtp_blocks += 1 + Assert.eq(num_heads, prediction_heads) + Assert.eq(num_mtp_blocks, prediction_heads - 1) + + batch_size = 2 + seq_length = 32 + x = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") + position_ids = torch.arange(seq_length, device="cuda", dtype=torch.int64) + attention_mask = torch.ones((1, 1, 1, 1), device="cuda", dtype=torch.bool) # will be broadcasted to right shape + labels = torch.randint(0, 49152, (batch_size, seq_length + model._config.prediction_heads - 1), device="cuda") + losses = {LanguageModelLossNames.multi_token_prediction_loss(i): [] for i in range(model._config.prediction_heads)} + kwargs = { + "position_ids": position_ids, + TransformerKwargs.sequence_first: False, + TransformerKwargs.attention_mask: attention_mask, + TransformerKwargs.attention_mask_value: -100, + TransformerKwargs.grad_output: True, + LanguageModelKwargs.labels: labels, + } + + if model._config.tie_word_embeddings: + kwargs[WORD_EMBEDDINGS_WEIGHT] = model.embedding.word_embeddings_weight + else: + kwargs[OUTPUT_WEIGHTS] = model.model_head.output_weights + + output = model( + x, + kwargs, + losses=losses, + ) + loss = sum( + [ + sum(losses[LanguageModelLossNames.multi_token_prediction_loss(i)]) + for i in range(model._config.prediction_heads) + ] + ) + loss.backward() diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 5863f9030..e6c9aafd1 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -13,25 +13,24 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from 
fast_llm.engine.schedule.schedule import Schedule from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames -from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from tests.common import get_hybrid_config, materialize_meta_tensors try: - from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba_layer import MambaLayer - from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMBaseModelConfig, HybridSSMModel + from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel except ImportError: - MambaLayer, LlambaBlock, HybridSSMBaseModel, HybridSSMBaseModelConfig, DiscreteMamba2 = ( - None, + MambaLayer, LlambaBlock, HybridSSMBaseModel, DiscreteMamba2 = ( None, None, None, None, ) - # Mamba not isntalled, skipping tests + # Mamba not installed, skipping tests try: from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel @@ -41,30 +40,6 @@ run_test = MambaLayer is not None and torch.cuda.is_available() -def materialize_meta_tensors(model, tensor_space): - # Materialize parameters that are on meta device - for name, param in model.named_parameters(): - if param.device.type == "meta": - # Check if the parameter is a custom tensor type - if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): - param_data = param.new_empty(param.shape, device="cuda") - # Initialize param_data - param.init_parameter(param_data, tensor_space.distributed) - # Replace the parameter in the module - module_path, param_name = name.rsplit(".", 1) if "." 
in name else (None, name) - module = model - if module_path is not None: - for part in module_path.split("."): - module = getattr(module, part) - param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation - param.grad = None - param.grad_buffer = torch.empty_like(param) - param.param_grad_is_zero = True - module._parameters[param_name] = param - return model - - @pytest.fixture def distributed_config(): return DistributedConfig( @@ -81,20 +56,6 @@ def distributed(distributed_config): return Distributed(config=distributed_config) -def get_hybrid_config(hybrid_block_layout=["t", "m", "t", "m"]): - config = HybridSSMBaseModelConfig( - transformer=TransformerConfig(num_layers=len(hybrid_block_layout)), - ssm=SSMConfig(), - hybrid_block_layout=hybrid_block_layout, - init_method_std_embed=0.02, - init_method_min_embed=-0.02, - init_method_max_embed=0.02, - use_position_embeddings=True, - tie_word_embeddings=False, - ) - return config - - def get_hf_llamba_out(input_ids, path, format): if format == LLambaHuggingfaceCheckpointFormat: from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel @@ -185,7 +146,7 @@ def test_load_from_llamba_checkpoint(distributed_config): (["m", "t"], MambaLayer), (["m2", "t"], DiscreteMamba2), ], - ids=["mamba", "descrete_mamba2"], + ids=["mamba", "discrete_mamba2"], ) def test_mamba_layer(distributed_config, distributed, hybrid_block_layout, LAYER_CLS): hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) @@ -248,12 +209,12 @@ def test_mamba_block(distributed_config, distributed): @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( - "hybrid_block_layout", + ("hybrid_block_layout"), [ (["m", "t"]), (["m2", "t"]), ], - ids=["mamba", "descrete_mamba2"], + ids=["mamba", "discrete_mamba2"], ) def 
test_hybrid_model_train_with_fast_mode(distributed_config, hybrid_block_layout): hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) @@ -275,7 +236,7 @@ def test_hybrid_model_train_with_fast_mode(distributed_config, hybrid_block_layo x, { "position_ids": position_ids, - TransformerKwargs.sequence_first: True, + TransformerKwargs.sequence_first: False, TransformerKwargs.attention_mask: attention_mask, TransformerKwargs.attention_mask_value: -100, TransformerKwargs.grad_output: True,