Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions fast_llm/layers/ssm/discrete_mamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
config: SSMConfig,
layer_idx: int,
tensor_space: TensorSpace,
return_input: bool = False,
):
"""
See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args.
Expand All @@ -42,6 +43,7 @@ def __init__(
self.config: SSMConfig = config
bias = config.add_bias_linear
self.layer_idx = layer_idx
self._return_input = return_input

td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim)
td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim)
Expand Down Expand Up @@ -122,10 +124,10 @@ def forward(self, hidden_states, kwargs):
outputs["hidden_states"]: (B, L, D).
outputs["state"]: inference cache.
"""
u = hidden_states
input_ = hidden_states
outputs = {}
# assert state is None
batch, seqlen, dim = u.shape
batch, seqlen, dim = input_.shape

state = None

Expand All @@ -134,7 +136,7 @@ def forward(self, hidden_states, kwargs):

# Pad input to nearest multiple of chunklen
padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size
u = torch.nn.functional.pad(u, (0, 0, 0, padded_len - seqlen))
u = torch.nn.functional.pad(input_, (0, 0, 0, padded_len - seqlen))

# Project input
xBCzA_log = self.in_proj(u)
Expand Down Expand Up @@ -198,10 +200,13 @@ def forward(self, hidden_states, kwargs):

# Norm and gate
out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias))
outputs["hidden_states"] = out[:, :seqlen, :]
outputs["hidden_states"] = out[:, :seqlen, :].contiguous()

if self._return_input:
return torch.stack([input_, outputs["hidden_states"]], dim=0)

# TODO: since we do not support inference for now, we only return the hidden states for now.
return outputs["hidden_states"].contiguous(), None
return outputs["hidden_states"], None

def convolutional_forward(self, xBC, padded_len):
"""Convolutional layer forward pass for the full sequence."""
Expand Down
4 changes: 4 additions & 0 deletions fast_llm/layers/ssm/mamba_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
config: SSMConfig,
layer_idx: int,
tensor_space: TensorSpace,
return_input: bool = False,
):
factory_kwargs = {}
super().__init__()
Expand Down Expand Up @@ -139,6 +140,7 @@ def __init__(
**factory_kwargs,
)
self.out_proj.weight.auto_grad_accumulation = True
self._return_input = return_input

def forward(self, hidden_states, kwargs):
batch, seqlen, dim = hidden_states.shape
Expand Down Expand Up @@ -170,4 +172,6 @@ def forward(self, hidden_states, kwargs):
delta_bias=self.dt_proj_bias.float(),
delta_softplus=True,
)
if self._return_input:
out = torch.stack((hidden_states, out), dim=0)
return out, None
23 changes: 16 additions & 7 deletions fast_llm/models/ssm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,14 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig):
)
hybrid_block_layout: list[str] = Field(
default_factory=lambda: ["m2"],
desc="Pattern of blocks to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for Descrete Mamba2.",
desc="Pattern of blocks to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for Discrete Mamba2",
hint=FieldHint.architecture,
)
default_mtp_type: str | None = Field(
default=None,
desc="Multi-token prediction mixer to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for discrete Mamba2. If None, will use the last block type in `hybrid_block_layout`.",
hint=FieldHint.optional,
)
use_megatron_initialization: bool = Field(
default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing
) # TODO: is this needed?
Expand Down Expand Up @@ -66,7 +71,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension))
tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2))

if "m2" in self.hybrid_block_layout:
if "m2" in self.hybrid_block_layout or self.default_mtp_type == "m2":
# Mamba2 specific dimensions
# as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4
headdim = d_inner // self.ssm.n_v_heads
Expand All @@ -84,23 +89,27 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim))

def _validate(self):
if len(self.hybrid_block_layout) != self.transformer.num_layers:
len_block_layout = len(self.hybrid_block_layout)
len_block_layout = len(self.hybrid_block_layout)
if len_block_layout != self.transformer.num_layers:
if self.transformer.num_layers % len_block_layout != 0:
raise ValueError(
f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}"
f"hybrid_block_layout length {len_block_layout} does not match num_layers {self.transformer.num_layers}"
)
num_repeats = int(self.transformer.num_layers // len_block_layout)
logger.warning(
f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times"
f"hybrid_block_layout length {len_block_layout} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times"
)
self.hybrid_block_layout = self.hybrid_block_layout * num_repeats

Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers)
Assert.eq(len_block_layout, self.transformer.num_layers)
Assert.custom(
lambda _: all(block_type in ["t", "m", "m2"] for block_type in self.hybrid_block_layout),
f"Invalid block type: {self.hybrid_block_layout}. Must be 't' or 'm' or 'm2'",
)
Assert.custom(
lambda _: self.default_mtp_type in ["t", "m", "m2", None],
f"Invalid MTP type: {self.default_mtp_type}. Must be 't' or 'm' or 'm2' or None",
)

super()._validate()

Expand Down
62 changes: 57 additions & 5 deletions fast_llm/models/ssm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[Conf
"""
A hybrid model that interleaves Transformer and Mamba blocks.
Right now only LlambaBlock is supported.
AS for the mixer, transformer uses MHA. For the LLlambaBlock we support Mamba1 and descrete mamba2.
As for the mixer, transformer uses MHA. For the LlambaBlock we support Mamba1 and discrete mamba2.
"""

config_class: typing.ClassVar[type[HybridSSMBaseModelConfig]] = HybridSSMBaseModelConfig
Expand All @@ -34,6 +34,51 @@ def __init__(
self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed
super().__init__(config, distributed_config)

def get_output_layers(self) -> list[Layer]:
"""
Get the output layers of the model.
This includes the language model head and any additional heads specified in the configuration.
"""
layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)]

if self._config.prediction_heads > 1:
block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1]
for i in range(1, self._config.prediction_heads):
if block_type == "t":
layers.append(
TransformerLayer(
self._config.transformer,
self._tensor_space,
layer_index=len(self._config.hybrid_block_layout),
return_input=i != self._config.prediction_heads - 1,
)
)
elif block_type == "m2":
mamba_block = self.SSM_BLOCK_CLS(
config_transformer=self._config.transformer,
config_ssm=self._config.ssm,
mixer_cls=DiscreteMamba2,
layer_index=len(self._config.hybrid_block_layout),
tensor_space=self._tensor_space,
return_input=i != self._config.prediction_heads - 1,
)
layers.append(mamba_block)
elif block_type == "m":
mamba_block = self.SSM_BLOCK_CLS(
config_transformer=self._config.transformer,
config_ssm=self._config.ssm,
mixer_cls=MambaLayer,
layer_index=len(self._config.hybrid_block_layout),
tensor_space=self._tensor_space,
return_input=i != self._config.prediction_heads - 1,
)
layers.append(mamba_block)
else:
raise ValueError(f"Invalid block type: {block_type}. Must be 't' or 'm' or 'm2'")
layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i))

return layers

def get_layers(self) -> list[Layer]:
"""
Create a list of layers for the model, interleaving Transformer and Mamba blocks
Expand All @@ -50,6 +95,9 @@ def get_layers(self) -> list[Layer]:
self._config.transformer,
self._tensor_space,
layer_index=i + 1,
return_input=(
i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1
),
)
)
elif block_type == "m2":
Expand All @@ -59,9 +107,11 @@ def get_layers(self) -> list[Layer]:
mixer_cls=DiscreteMamba2,
layer_index=i + 1,
tensor_space=self._tensor_space,
return_input=(
i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1
),
)
layers.append(mamba_block)

elif block_type == "m":
# Create Mamba block
mamba_block = self.SSM_BLOCK_CLS(
Expand All @@ -70,14 +120,16 @@ def get_layers(self) -> list[Layer]:
mixer_cls=MambaLayer,
layer_index=i + 1,
tensor_space=self._tensor_space,
return_input=(
i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1
),
)
layers.append(mamba_block)

else:
raise ValueError(f"Invalid block type: {block_type}. Must be 't' or 'm' or 'm2'")

# Add the language model head
layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=0))
# Add the output layers
layers += self.get_output_layers()

return layers

Expand Down
44 changes: 43 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.layers.ssm.config import SSMConfig
from fast_llm.layers.transformer.config import TransformerConfig
from fast_llm.models.gpt.config import (
LlamaGPTHuggingfaceCheckpointFormat,
MistralGPTHuggingfaceCheckpointFormat,
Expand All @@ -21,7 +23,7 @@
Qwen2GPTHuggingfaceCheckpointFormat,
Starcoder2GPTHuggingfaceCheckpointFormat,
)
from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat
from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat
from fast_llm.tools.train import CliTrainingConfig
from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs

Expand Down Expand Up @@ -419,3 +421,43 @@ def run_test_script(
TEST_RESULTS_PATH / name / ARTIFACT_PATH,
config,
)


def materialize_meta_tensors(model, tensor_space):
# Materialize parameters that are on meta device
for name, param in model.named_parameters():
if param.device.type == "meta":
# Check if the parameter is a custom tensor type
if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"):
param_data = param.new_empty(param.shape, device="cuda")
# Initialize param_data
param.init_parameter(param_data, tensor_space.distributed)
# Replace the parameter in the module
module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name)
module = model
if module_path is not None:
for part in module_path.split("."):
module = getattr(module, part)
param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad)
# TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation
param.grad = None
param.grad_buffer = torch.empty_like(param)
param.param_grad_is_zero = True
module._parameters[param_name] = param
return model


def get_hybrid_config(hybrid_block_layout=["t", "m"], prediction_heads=1, default_mtp_type=None):
config = HybridSSMBaseModelConfig(
transformer=TransformerConfig(num_layers=len(hybrid_block_layout)),
ssm=SSMConfig(),
hybrid_block_layout=hybrid_block_layout,
prediction_heads=prediction_heads,
default_mtp_type=default_mtp_type,
init_method_std_embed=0.02,
init_method_min_embed=-0.02,
init_method_max_embed=0.02,
use_position_embeddings=True,
tie_word_embeddings=False,
)
return config
Loading
Loading