From 513775730095c5e53e4d982139a250aba9c99b52 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 00:10:47 -0400 Subject: [PATCH 001/122] stuff --- fast_llm/config.py | 177 ++++++++++++++++++----- fast_llm/data/data/config.py | 11 +- fast_llm/data/data/gpt/config.py | 12 +- fast_llm/data/dataset/config.py | 26 ++-- fast_llm/data/dataset/gpt/config.py | 13 +- fast_llm/engine/checkpoint/config.py | 5 +- fast_llm/engine/checkpoint/external.py | 4 +- fast_llm/engine/distributed/config.py | 3 +- fast_llm/engine/schedule/config.py | 7 +- fast_llm/engine/training/config.py | 4 +- fast_llm/layers/language_model/config.py | 23 +-- fast_llm/layers/transformer/config.py | 94 ++++++------ fast_llm/profile.py | 4 +- fast_llm/utils.py | 34 ----- tests/data/common.py | 7 +- 15 files changed, 241 insertions(+), 183 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index f1c889658..326845f0a 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1,3 +1,4 @@ +import contextlib import dataclasses import enum import logging @@ -9,7 +10,7 @@ import yaml -from fast_llm.utils import Assert, Tag, get_type_name, header, log, pop_nested_dict_value, set_nested_dict_value +from fast_llm.utils import Assert, Tag, get_type_name, header, log logger = logging.getLogger(__name__) @@ -43,6 +44,13 @@ class _ConfigDictFormat(str, enum.Enum): tuple = "tuple" +class UpdateType(str, enum.Enum): + # Override entries no matter what they contais. + override = "override" + # Override atomic entries and lists, but update dicts recursively by setting or overriding only the specified entries. + update = "update" + + class FieldHint: """ A label defined for each config field, to let the user and some methods know how important each field is. @@ -125,6 +133,9 @@ def __init__( # Should raise an Exception in case of failure, and return the validated value. # Run before the default validation (type check). valid: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None, + # Option to skip (postpone) instantiation of a `Config` field. + # Note: The config still needs to be instantiated for validation to succeed. + # auto_instantiate: bool = True, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init: bool = True, @@ -152,6 +163,7 @@ def __init__( self.doc = doc self.hint = hint self.valid = valid + # self.auto_instantiate = auto_instantiate class FieldUpdate(dict): @@ -254,7 +266,16 @@ def config_class(cls=None): def wrap(cls): Assert.custom(issubclass, cls, Config) - return _process_config_class(dataclasses.dataclass(cls)) + wrapped = _process_config_class(dataclasses.dataclass(cls)) + + wrapped_init = cls.__init__ + + def __init__(self, **kwargs): + wrapped_init(self, **kwargs) + self._explicit_fields = set(kwargs) + + cls.__init__ = __init__ + return wrapped # See if we're being called as @config_class or @config_class(). if cls is None: @@ -277,9 +298,17 @@ class Config: # We can't use @config_class on this one because it needs this class to be defined, so we assume this one is OK. __class_validated__: typing.ClassVar[bool] = True + # Set to true to prevent instantiation. _abstract: typing.ClassVar[bool] = False + # Keep track of whether an instance has been validated _validated: bool = Field(init=False, repr=False) + # Keep track of unknown fields so they can be reported during validation. _unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False) + # Keep track of explicitly set fields to ensure they get serialized and used as config updates. + _explicit_fields: set[str] = Field(init=False, repr=False) + # Used within `_set_implicit_default` to set implicit defaults for fields + # without them being automatically added to `_explicit_fields`. + _setting_implicit_default: bool = Field(init=False, repr=False) def __post_init__(self): """ @@ -288,6 +317,7 @@ def __post_init__(self): and all post-processing should be done in `_validate` """ self._validated = False + self._setting_implicit_default = False if _AUTO_VALIDATE: self.validate() @@ -305,6 +335,12 @@ def __setattr__(self, key: str, value: typing.Any) -> None: f"Cannot set attribute `{key}`" f" in configuration class `{get_type_name(type(self))}` after validation." ) + elif not getattr(self, "_setting_implicit_default", True): + field = self.get_field(key) + if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: + # Adding to explicit field list except within `_set_implicit_default` context + # and during dataclass initialization (`_setting_implicit_default` not yet set). + self._explicit_fields.add(key) super().__setattr__(key, value) def __delattr__(self, key: str) -> None: @@ -318,6 +354,12 @@ def __delattr__(self, key: str) -> None: ) super().__delattr__(key) + @contextlib.contextmanager + def _set_implicit_default(self): + self._setting_implicit_default = True + yield + self._setting_implicit_default = False + def validate[T](self: T, *, _is_validating: bool = False) -> T: """ Validate a class and mark it as read-only @@ -332,6 +374,7 @@ def validate[T](self: T, *, _is_validating: bool = False) -> T: else: raise type(e)("\n".join(e.args)) from None self._validated = True + print("WLIEHGIUWERGNHBWIO", self.__class__.__name__, self._explicit_fields) return self def _validate(self) -> None: @@ -344,16 +387,17 @@ def _validate(self) -> None: """ self._check_abstract() errors = [] - for name, field in self.fields(): - if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa - continue - value = getattr(self, name) - if value is DEFAULT: - # Replace the value with its default. - # We still need to validate because some fields have invalid defaults. - value = field.default - new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False) - setattr(self, name, new_value) + with self._set_implicit_default(): + for name, field in self.fields(): + if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa + continue + value = getattr(self, name) + if value is DEFAULT: + # Replace the value with its default. + # We still need to validate because some fields have invalid defaults. + value = field.default + new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False) + setattr(self, name, new_value) for name in getattr(self, "_unknown_fields", {}): errors.append(f"Unknown field `{name}` in class {self._get_class_name()}") if errors: @@ -555,9 +599,8 @@ def _to_dict( return arg_dict - @classmethod def _add_field_to_args( - cls, + self, args: dict | list, name: str | None, field: Field | None, @@ -574,46 +617,48 @@ def _add_field_to_args( ): # Exclude class variables and derived fields unless requested explicitly. return - elif isinstance(value, Config): + explicit_field = ( + field is None + or name in self._explicit_fields + or (verbose is not None and verbose >= FieldHintImportance[field.hint]) + ) + if isinstance(value, Config): field_value = value._to_dict( verbose=verbose, all_fields=all_fields, format_=format_, serializable=serializable, ) + # Empty configs can safely be trimmed. + explicit_field = all_fields elif isinstance(value, (list, tuple, set)): field_value = {} if format_ == _ConfigDictFormat.tuple else [] for i, list_value in enumerate(value): - cls._add_field_to_args( + self._add_field_to_args( field_value, str(i), None, list_value, verbose, all_fields, format_, serializable ) elif isinstance(value, dict): field_value = {} for dict_name, dict_value in value.items(): - cls._add_field_to_args( + self._add_field_to_args( field_value, dict_name, None, dict_value, verbose, all_fields, format_, serializable ) - elif ( - verbose is not None - and field is not None - and FieldHintImportance[field.hint] > verbose - and value == field.default - ): - # Exclude unimportant default values. - return - else: + elif explicit_field: field_value = value if serializable: - field_value = cls._serialize_value(value) + field_value = self._serialize_value(value) if format_ == _ConfigDictFormat.tuple: field_value = {(): field_value} + else: + # Exclude unimportant (implicit or explicit) default values. + return if serializable: - name = cls._serialize_value(name) + name = self._serialize_value(name) if format_ == _ConfigDictFormat.tuple: args.update({(name,) + name_: value_ for name_, value_ in field_value.items()}) elif format_ == _ConfigDictFormat.nested: - if not isinstance(field_value, (dict, list)) or len(field_value) > 0 or all_fields: + if not isinstance(field_value, (dict, list)) or len(field_value) > 0 or explicit_field or all_fields: if isinstance(args, dict): args[name] = field_value else: @@ -671,6 +716,7 @@ def from_dict( default: typing.Union["Config", dict[str, typing.Any]], *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True, + update_type: UpdateType = UpdateType.override, ) -> typing.Self: if isinstance(default, Config): default = default._to_dict() @@ -678,7 +724,7 @@ def from_dict( if isinstance(update, Config): update = update._to_dict(format_=_ConfigDictFormat.tuple) for keys, value in update.items(): - set_nested_dict_value(default, keys, value) + set_nested_dict_value(default, keys, value, update_type) return cls._from_dict(default, strict) @@ -712,10 +758,7 @@ def _from_dict( continue if flat: if isinstance(field.type, type) and issubclass(field.type, Config): - if flat: - out_arg_dict[name] = field.type._from_dict(default, False, True) - else: - out_arg_dict[name] = field.type._from_dict(default.pop(name, {}), strict) + out_arg_dict[name] = field.type._from_dict(default, False, True) elif name in default: out_arg_dict[name] = default.pop(name) else: @@ -916,3 +959,69 @@ def __init__(self, config: ConfigType, *args, **kwargs): @property def config(self) -> ConfigType: return self._config + + +def set_nested_dict_value[ + KeyType, ValueType +]( + d: dict[KeyType, ValueType], + keys: KeyType | tuple[KeyType, ...], + value: ValueType, + update_type: UpdateType = UpdateType.override, +) -> None: + if isinstance(keys, tuple): + for key in keys[:-1]: + d = d.setdefault(key, {}) + assert isinstance(d, dict) + key = keys[-1] + else: + key = keys + if update_type == UpdateType.override: + d[key] = value + elif update_type == UpdateType.update: + # TODO: Improve error messages, ex. for nested cases? + if isinstance(d[key], Config): + raise ValueError("Cannot update an already instantiated config.") + elif isinstance(value, Config): + raise ValueError("Cannot update a config dict with an already instantiated config.") + elif isinstance(d, dict): + if key in d: + Assert.custom(isinstance, d[key], dict) + else: + d[key] = {} + for key_, value_ in value.items(): + set_nested_dict_value(d, key_, value_, update_type) + elif ( + isinstance(value, (list, set, tuple)) + and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in value) + ) or ( + isinstance(d[key], (list, set, tuple)) + and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in d[key]) + ): + raise ValueError("Update not supported for nested lists.") + else: + d[key] = value + else: + raise NotImplementedError(update_type) + + +def get_nested_dict_value[ + KeyType, ValueType +](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType: + if isinstance(keys, tuple): + for key in keys: + d = d[key] + return d + else: + return d[keys] + + +def pop_nested_dict_value[ + KeyType, ValueType +](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType: + if isinstance(keys, tuple): + for key in keys[:-1]: + d = d[key] + return d.pop(keys[-1]) + else: + return d.pop(keys) diff --git a/fast_llm/data/data/config.py b/fast_llm/data/data/config.py index 752fdfd17..25850ac3f 100644 --- a/fast_llm/data/data/config.py +++ b/fast_llm/data/data/config.py @@ -1,18 +1,9 @@ import typing -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class +from fast_llm.config import Config, Field, config_class from fast_llm.data.dataset.config import SamplingConfig, SamplingData -@config_class() -class SamplingDefaultConfig(SamplingConfig): - seed: int = FieldUpdate( - default=784569, - desc="Seed for random sampling.", - hint=FieldHint.feature, - ) - - @config_class() class DataConfig(Config): _abstract = True diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index cbbfa036d..d1d6bd404 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -2,13 +2,12 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.config import MultiprocessingContext, TokenizerConfig -from fast_llm.data.data.config import DataConfig, SamplingDefaultConfig +from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.gpt.config import ( GPTLegacyConfig, GPTLegacyDatasetConfig, GPTSampledDatasetConfig, GPTSamplingConfig, - ShufflingType, ) from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert @@ -16,13 +15,6 @@ logger = logging.getLogger(__name__) -@config_class() -class GPTSamplingDefaultConfig(SamplingDefaultConfig, GPTSamplingConfig): - gpu: bool = FieldUpdate(default=True) - use_loss_masking_spans: bool = FieldUpdate(default=False) - shuffle: ShufflingType = FieldUpdate(default=ShufflingType.epoch) - - @config_class() class GPTDataConfig(DataConfig, GPTLegacyConfig): """ @@ -44,7 +36,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingDefaultConfig = FieldUpdate(default_factory=GPTSamplingDefaultConfig) + sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 431a28a07..7808158be 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -5,7 +5,7 @@ import pathlib import typing -from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, check_field, config_class +from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, normalize_probabilities @@ -17,20 +17,12 @@ @config_class() class SamplingConfig(Config): - seed: int | None = Field( - default=None, + seed: int = Field( + default=784569, desc="Seed for random sampling.", hint=FieldHint.feature, ) - @property - def updates(self) -> dict[str, typing.Any]: - return { - key: value - for key, value in self.to_serialized(verbose=FieldVerboseLevel.everything).items() - if value is not None - } - @dataclasses.dataclass(kw_only=True) class SamplingData: @@ -44,10 +36,10 @@ class SamplingData: # Using a mutable rather than an int so it's shared with all copies made with `update`. _rank_counter: typing.Iterator[int] = itertools.count - def update(self, config: SamplingConfig, **kwargs): - if config_updates := config.updates: - kwargs["config"] = self.config.to_copy(config_updates) - return dataclasses.replace(self, **kwargs) if kwargs else self + def update_config(self, update: SamplingConfig): + return dataclasses.replace( + self, config=self.config.from_dict(self.config, update, update_type=UpdateType.update) + ) def get_next_rank(self) -> int: # Counter that loops over ranks to try to distribute workloads evenly between ranks. @@ -163,7 +155,7 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): Only explicitly set parameters (not None) will be updated, other will still be taken from `build_and_sample`'s argument. """ - _abstract = False + _abstract = True sampling: SamplingConfig = Field( default_factory=SamplingConfig, desc="Optional override to sampling configuration parameters.", @@ -176,7 +168,7 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): ) def build_and_sample(self, data: SamplingData) -> SampledDataset: - return self.dataset.build_and_sample(data.update(self.sampling)) + return self.dataset.build_and_sample(data.update_config(self.sampling)) @config_class() diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 74d8a0c35..118b3039d 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -45,20 +45,20 @@ class ShufflingType(str, enum.Enum): @config_class() class GPTSamplingConfig(SamplingConfig): - gpu: bool | None = Field( - default=None, + gpu: bool = Field( + default=True, desc="Enable fast sampling on GPU." " Note that random sampling works differently on GPU," " so the sample won't match the CPU equivalent.", hint=FieldHint.feature, ) - use_loss_masking_spans: bool | None = Field( - default=None, + use_loss_masking_spans: bool = Field( + default=False, desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) - shuffle: ShufflingType | None = Field( - default=None, + shuffle: ShufflingType = Field( + default=ShufflingType.epoch, desc="Shuffling strategy.", hint=FieldHint.feature, ) @@ -210,6 +210,7 @@ def build(self) -> "GPTDatasetSlice": @config_class() class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): + _abstract = False type_: typing.ClassVar[str | None] = "sampled" sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) dataset: GPTSampledDatasetConfig = FieldUpdate(default_factory=GPTSampledDatasetConfig) diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 92f1165d4..46c8f483b 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -164,8 +164,9 @@ class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateCon def _validate(self) -> None: if self.optimizer_state is None: - # TODO: Make sure it's a type - self.optimizer_state = self.format.support_optimizer + with self._set_implicit_default(): + # TODO: Make sure it's a type + self.optimizer_state = self.format.support_optimizer super()._validate() if self.optimizer_state: assert self.format.support_optimizer diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 83514c86e..76f5e336f 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -7,14 +7,14 @@ import torch from fast_llm import __version__ -from fast_llm.config import MISSING +from fast_llm.config import MISSING, get_nested_dict_value, set_nested_dict_value from fast_llm.engine.base_model.config import BaseModelArchitectureConfig from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.state_dict import StateDictCheckpointHandler from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert, get_nested_dict_value, set_nested_dict_value +from fast_llm.utils import Assert logger = logging.getLogger(__name__) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 1b3e73bb6..76c496ac9 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -279,7 +279,8 @@ def _validate(self) -> None: self.tensor_rank = self.rank % self.tensor_parallel if self.tensor_parallel == 1: - self.sequence_tensor_parallel = False + with self._set_implicit_default(): + self.sequence_tensor_parallel = False self.distributed_dims = {} diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 83d3d51a3..91256deb8 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -79,10 +79,6 @@ def setup(self, distributed_config: DistributedConfig) -> None: def num_inputs(self) -> int: return self.sequential_micro_batches * self.num_micro_sequences - @property - def _is_setup(self) -> bool: - return hasattr(self, "_distributed") - def _validate(self) -> None: # Use the distributed properties to determine the batch size and its breakdown. # Requires post-processed distributed config args @@ -133,7 +129,8 @@ def _validate(self) -> None: " Use at your own risk." ) if self.micro_sequence_length is None: - self.micro_sequence_length = self.sequence_length + with self._set_implicit_default(): + self.micro_sequence_length = self.sequence_length self.num_micro_sequences = div(self.sequence_length, self.micro_sequence_length) super()._validate() diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 30add2f4b..3a65bbc9e 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -42,7 +42,8 @@ class IntervalConfig(Config): def _validate(self) -> None: if self.interval: - self.offset %= self.interval + with self._set_implicit_default(): + self.offset %= self.interval super()._validate() def enabled(self, iteration: int | None = None) -> bool: @@ -109,6 +110,7 @@ class WandbAlertConfig(IntervalConfig): "The update may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.feature, ) + post_alerts: bool = Field(init=False, repr=False) def _validate(self) -> None: if self.status_updates is None: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 8e3a467cc..fa5d49201 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -60,7 +60,8 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): def _validate(self) -> None: if self.use_position_embeddings is None: - self.use_position_embeddings = not self.transformer.rotary.enabled + with self._set_implicit_default(): + self.use_position_embeddings = not self.transformer.rotary.enabled super()._validate() def setup_tensor_space(self, tensor_space: TensorSpace) -> None: @@ -175,14 +176,14 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): ) def _validate(self) -> None: - if self.transformer.init_method_std is None: - self.transformer.init_method_std = self.transformer.hidden_size**-0.5 - if self.init_method_std_embed is None: - self.init_method_std_embed = self.transformer.init_method_std - if self.init_method_max_embed is None: - self.init_method_max_embed = self.transformer.init_method_max - if self.init_method_min_embed is None: - self.init_method_min_embed = self.transformer.init_method_min - if self.init_method_max_embed is not None and self.init_method_min_embed is not None: - Assert.leq(self.init_method_min_embed, self.init_method_max_embed) + self.transformer.validate() + with self._set_implicit_default(): + if self.init_method_std_embed is None: + self.init_method_std_embed = self.transformer.init_method_std + if self.init_method_max_embed is None: + self.init_method_max_embed = self.transformer.init_method_max + if self.init_method_min_embed is None: + self.init_method_min_embed = self.transformer.init_method_min + if self.init_method_max_embed is not None and self.init_method_min_embed is not None: + Assert.leq(self.init_method_min_embed, self.init_method_max_embed) super()._validate() diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 1352c7f05..139831372 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -250,12 +250,13 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig): ) def _validate(self) -> None: - if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size - if self.kv_channels is None: - self.kv_channels = div(self.hidden_size, self.num_attention_heads) - if self.activation_type is None: - self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + with self._set_implicit_default(): + if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + if self.kv_channels is None: + self.kv_channels = div(self.hidden_size, self.num_attention_heads) + if self.activation_type is None: + self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu self.projection_size = self.num_attention_heads * self.kv_channels self.num_unshared_experts = self.num_experts - self.num_shared_experts @@ -569,46 +570,47 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): ) def _validate(self) -> None: - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 - if self.init_method_std_qkv is None: - self.init_method_std_qkv = self.init_method_std - if self.init_method_std_attn_proj is None: - self.init_method_std_attn_proj = self.init_method_std / (2 * self.num_layers) ** 0.5 - if self.init_method_std_mlp_1 is None: - self.init_method_std_mlp_1 = self.init_method_std - if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5 - if self.mlp_lr_scale is None or len(self.mlp_lr_scale) == 0: - self.mlp_lr_scale = [None] - if self.init_method_max_qkv is None: - self.init_method_max_qkv = self.init_method_max - if self.init_method_min_qkv is None: - self.init_method_min_qkv = self.init_method_min - if self.init_method_max_attn_proj is None: - self.init_method_max_attn_proj = self.init_method_max - if self.init_method_min_attn_proj is None: - self.init_method_min_attn_proj = self.init_method_min - if self.init_method_max_mlp_1 is None: - self.init_method_max_mlp_1 = self.init_method_max - if self.init_method_min_mlp_1 is None: - self.init_method_min_mlp_1 = self.init_method_min - if self.init_method_max_mlp_2 is None: - self.init_method_max_mlp_2 = self.init_method_max - if self.init_method_min_mlp_2 is None: - self.init_method_min_mlp_2 = self.init_method_min - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) - if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: - Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) - if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: - Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) - if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: - Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) + with self._set_implicit_default(): + if self.init_method_std is None: + self.init_method_std = self.hidden_size**-0.5 + if self.init_method_std_qkv is None: + self.init_method_std_qkv = self.init_method_std + if self.init_method_std_attn_proj is None: + self.init_method_std_attn_proj = self.init_method_std / (2 * self.num_layers) ** 0.5 + if self.init_method_std_mlp_1 is None: + self.init_method_std_mlp_1 = self.init_method_std + if self.init_method_std_mlp_2 is None: + self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5 + if self.mlp_lr_scale is None or len(self.mlp_lr_scale) == 0: + self.mlp_lr_scale = [None] + if self.init_method_max_qkv is None: + self.init_method_max_qkv = self.init_method_max + if self.init_method_min_qkv is None: + self.init_method_min_qkv = self.init_method_min + if self.init_method_max_attn_proj is None: + self.init_method_max_attn_proj = self.init_method_max + if self.init_method_min_attn_proj is None: + self.init_method_min_attn_proj = self.init_method_min + if self.init_method_max_mlp_1 is None: + self.init_method_max_mlp_1 = self.init_method_max + if self.init_method_min_mlp_1 is None: + self.init_method_min_mlp_1 = self.init_method_min + if self.init_method_max_mlp_2 is None: + self.init_method_max_mlp_2 = self.init_method_max + if self.init_method_min_mlp_2 is None: + self.init_method_min_mlp_2 = self.init_method_min + if self.init_method_min is not None and self.init_method_max is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) + if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: + Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) + if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: + Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) + if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: + Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) super()._validate() Assert.geq(self.attention_dropout, 0) Assert.geq(self.hidden_dropout, 0) diff --git a/fast_llm/profile.py b/fast_llm/profile.py index a0fc3946a..a3902cf1e 100644 --- a/fast_llm/profile.py +++ b/fast_llm/profile.py @@ -94,7 +94,9 @@ def _validate(self) -> None: self.global_attention_layers = set() profile_ranks = set(self.ranks or []) Assert.eq(len(profile_ranks), len(self.ranks or [])) - self.ranks = profile_ranks # noqa + with self._set_implicit_default(): + self.ranks = profile_ranks # noqa + super()._validate() def get_profiler( self, *, distributed_config: DistributedConfig | None = None, start_step: int = 0 diff --git a/fast_llm/utils.py b/fast_llm/utils.py index d650fa94f..b0e482311 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -249,40 +249,6 @@ def normalize_probabilities(p: "npt.ArrayLike", return_array: bool = False) -> " return out if return_array else out.tolist() -def set_nested_dict_value[ - KeyType, ValueType -](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...], value: ValueType) -> None: - if isinstance(keys, tuple): - for key in keys[:-1]: - d = d.setdefault(key, {}) - assert isinstance(d, dict) - d[keys[-1]] = value - else: - d[keys] = value - - -def get_nested_dict_value[ - KeyType, ValueType -](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType: - if isinstance(keys, tuple): - for key in keys: - d = d[key] - return d - else: - return d[keys] - - -def pop_nested_dict_value[ - KeyType, ValueType -](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType: - if isinstance(keys, tuple): - for key in keys[:-1]: - d = d[key] - return d.pop(keys[-1]) - else: - return d.pop(keys) - - class InvalidObject: """ Store an error and raise it if accessed. diff --git a/tests/data/common.py b/tests/data/common.py index 917b4914c..bdfd54a7c 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -5,12 +5,13 @@ import torch from fast_llm.config import Field, FieldHint, NoAutoValidate, config_class -from fast_llm.data.data.gpt.config import GPTDataConfig, GPTSamplingDefaultConfig +from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import ( GPTIndexedDatasetConfig, GPTSampledDatasetConfig, + GPTSamplingConfig, GPTSamplingData, ShufflingType, ) @@ -39,7 +40,7 @@ def get_sampling_data( ) -> GPTSamplingData: # Config with convenient defaults. return GPTSamplingData( - config=GPTSamplingDefaultConfig( + config=GPTSamplingConfig( seed=seed, gpu=gpu, shuffle=shuffle, @@ -76,7 +77,7 @@ def get_test_data_and_compare_samples( distributed_config = DistributedConfig(seed=seed if legacy else 87522) distributed = Distributed(distributed_config, use_cpu=True) assert "sampling" not in config - config["sampling"] = GPTSamplingDefaultConfig( + config["sampling"] = GPTSamplingConfig( seed=87522 if legacy else seed, gpu=gpu, shuffle=shuffle, From f26010ef9f8cfd070734751f9dec45a364496308 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 21:21:45 -0400 Subject: [PATCH 002/122] Update pretrained config --- fast_llm/config.py | 5 +-- fast_llm/engine/checkpoint/config.py | 1 + fast_llm/engine/checkpoint/distributed.py | 6 +-- fast_llm/engine/huggingface/config.py | 5 +-- fast_llm/engine/huggingface/model.py | 8 ++-- fast_llm/engine/multi_stage/config.py | 42 +++---------------- fast_llm/engine/multi_stage/fast_llm_model.py | 7 ++-- fast_llm/layers/transformer/config.py | 17 ++++---- fast_llm/layers/transformer/mlp.py | 4 +- 9 files changed, 28 insertions(+), 67 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 326845f0a..5436a2947 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -374,7 +374,6 @@ def validate[T](self: T, *, _is_validating: bool = False) -> T: else: raise type(e)("\n".join(e.args)) from None self._validated = True - print("WLIEHGIUWERGNHBWIO", self.__class__.__name__, self._explicit_fields) return self def _validate(self) -> None: @@ -713,8 +712,8 @@ def _get_class_name(cls) -> str: @classmethod def from_dict( cls, - default: typing.Union["Config", dict[str, typing.Any]], - *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], + default: "Config| dict[str, typing.Any]]", + *updates: "Config| dict[str | tuple[str, ...], typing.Any]", strict: bool = True, update_type: UpdateType = UpdateType.override, ) -> typing.Self: diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 46c8f483b..621f7fe8d 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -200,6 +200,7 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf @config_class() class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): + # TODO!!!!!!! _abstract = False load_config: ModelConfigType = Field( diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 9c171befa..a920a52c2 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -13,7 +13,6 @@ CheckpointLoadMetadataConfig, CheckpointSaveConfig, DistributedCheckpointFormat, - ModelConfigType, export_safetensors_metadata, ) from fast_llm.engine.checkpoint.safe_load import SafeLoad @@ -43,15 +42,14 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: # TODO: More safety checks - loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm}) - loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata) num_shards = self.get_num_shards(config) shard_names = self.get_shard_names(config) Assert.eq(metadata.shards[:num_shards], list(shard_names)) same_format = ( - loaded_config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None) + type(metadata.config) == type(self._model.config) and config.optimizer_state + and metadata.config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None) ) # Make sure all nodes agree on which loading scheme to use. # Note: they may not agree before the broadcast because of the rank comparison, but that's ok. diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index e02abc28e..e79857c91 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -73,10 +73,7 @@ def _get_config_dict( torch_dtype = kwargs.pop("torch_dtype", None) if torch_dtype is not None: updates[("distributed", "training_dtype")] = torch_dtype - fast_llm_config = cls.model_config_class.from_metadata( - pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates - ) - + fast_llm_config = cls.model_config_class.from_dict(metadata.config, kwargs.pop("fast_llm_config", {}), updates) config_dict = {"fast_llm_config": fast_llm_config} return config_dict, kwargs diff --git a/fast_llm/engine/huggingface/model.py b/fast_llm/engine/huggingface/model.py index 499f0af12..e4f2cd999 100644 --- a/fast_llm/engine/huggingface/model.py +++ b/fast_llm/engine/huggingface/model.py @@ -73,15 +73,13 @@ def from_pretrained( format=FastLLMCheckpointFormat, ) - config_updates = {} + updates = {} torch_dtype = kwargs.pop("torch_dtype", None) if torch_dtype is not None: - config_updates[("distributed", "training_dtype")] = torch_dtype + updates[("distributed", "training_dtype")] = torch_dtype # Create the model - fast_llm_model = cls.model_class.from_pretrained( - pretrained_model_name_or_path, config_updates=config_updates, mode=mode - ) + fast_llm_model = cls.model_class.from_pretrained(pretrained_model_name_or_path, updates, mode=mode) config = cls.config_class(fast_llm_model.config) return cls(config, fast_llm_model, **kwargs) diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index d6997105c..d8333c9ba 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -246,46 +246,12 @@ def get_base_model_config_class(cls) -> type[BaseModelConfig]: @classmethod def from_pretrained( - cls, - pretrained: CheckpointLoadMetadataConfig, - default: typing.Self | None = None, + cls, pretrained: CheckpointLoadMetadataConfig, *updates: Config | dict[str | tuple[str, ...], typing.Any] ) -> typing.Self: # TODO: Add *updates? assert pretrained.path is not None metadata = cls.load_metadata(pretrained) - return cls.from_metadata(pretrained, metadata, default) - - @classmethod - def from_metadata( - cls, - pretrained: CheckpointLoadMetadataConfig, - metadata: "CheckpointMetadata", - default: typing.Self | None = None, - updates: dict[str | tuple[str, ...], typing.Any] | None = None, - ) -> typing.Self: - # TODO: Standardize to *updates? - # TODO v0.3: Update, remove support for older checkpoints. - if metadata.fast_llm_version.major != 0 or metadata.fast_llm_version.minor not in (0, 1, 2): - raise ValueError(f"Invalid checkpoint version: {metadata.fast_llm_version}") - pretrained_config = cls.from_dict(metadata.config) - if not pretrained.load_config.load_architecture: - assert default is not None - config = default.to_copy() - config.base_model.compare_architecture(pretrained_config.base_model, pretrained.compare_log_fn) - elif pretrained.load_config.load_fast_llm: - config = pretrained_config - else: - with NoAutoValidate(): - config = cls() if default is None else default.to_copy() - if pretrained.load_config.load_base_model: - config.base_model = pretrained_config.base_model - else: - config.base_model = config.base_model.to_copy(pretrained_config.base_model.get_architecture()) - config.validate() - - if updates: - config = config.to_copy(updates) - return config + return cls.from_dict(metadata.config, *updates) @classmethod def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": @@ -328,7 +294,7 @@ def _validate(self) -> None: self.pretrained.setup(self.model) self.pretrained.validate() if self.pretrained.path is not None: - self.model = self.model.from_pretrained(self.pretrained, default=self.model) + self.model = self.model.from_pretrained(self.pretrained, self.model) self._setup() super()._validate() @@ -380,6 +346,8 @@ def _validate(self) -> None: self.format = self.model.get_checkpoint_format(self.format) super()._validate() + if self.fast_llm_version.major != 0 or self.fast_llm_version.minor not in (0, 1, 2): + raise ValueError(f"Invalid checkpoint version: {self.fast_llm_version}") Assert.eq(self.config.__class__, self.model) @classmethod diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index b268ec29e..22e5ccaca 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -1,6 +1,7 @@ import logging import typing +from fast_llm.config import UpdateType from fast_llm.core.distributed import broadcast from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig from fast_llm.engine.distributed.distributed import Distributed @@ -45,9 +46,7 @@ def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] def from_pretrained( cls, pretrained_config: CheckpointLoadConfig, - default_config: FastLLMModelConfig = None, - *, - config_updates: dict[str | tuple[str, ...], typing.Any] | None = None, + *updates: dict[str | tuple[str, ...], typing.Any], optimizer_state_names: tuple[str, ...] | None = None, setup: bool = True, mode: StageMode = StageMode.training, @@ -55,7 +54,7 @@ def from_pretrained( stage_filter: set | None = None, ) -> typing.Self: metadata = cls.config_class.load_metadata(pretrained_config) - config = cls.config_class.from_metadata(pretrained_config, metadata, default_config, config_updates) + config = cls.config_class.from_dict(metadata.config, *updates, update_type=UpdateType.update) if mode.support_training: # TODO v0.3: Make metadata.shards mandatory? if metadata.shards: diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 139831372..9410157b5 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -532,8 +532,8 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - mlp_lr_scale: list[float | None] = Field( - default_factory=list, + mlp_lr_scale: float | None | list[float | None] = Field( + default=None, desc="Custom learning rate scale for each expert.", doc="May be used to freeze some experts by setting their scale to zero.", hint=FieldHint.feature, @@ -581,8 +581,6 @@ def _validate(self) -> None: self.init_method_std_mlp_1 = self.init_method_std if self.init_method_std_mlp_2 is None: self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5 - if self.mlp_lr_scale is None or len(self.mlp_lr_scale) == 0: - self.mlp_lr_scale = [None] if self.init_method_max_qkv is None: self.init_method_max_qkv = self.init_method_max if self.init_method_min_qkv is None: @@ -614,10 +612,13 @@ def _validate(self) -> None: super()._validate() Assert.geq(self.attention_dropout, 0) Assert.geq(self.hidden_dropout, 0) - Assert.incl(len(self.mlp_lr_scale), (1, self.num_experts)) - for scale in self.mlp_lr_scale: - if scale is not None: - Assert.geq(scale, 0) + if isinstance(self.mlp_lr_scale, list): + Assert.eq(len(self.mlp_lr_scale), self.num_experts) + for scale in self.mlp_lr_scale: + if scale is not None: + Assert.geq(scale, 0) + elif self.mlp_lr_scale is not None: + Assert.geq(self.mlp_lr_scale, 0) def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in ( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index adc6242dc..ff4eaf268 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -45,7 +45,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, - lr_scale=tuple(config.mlp_lr_scale), + lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, @@ -55,7 +55,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, - lr_scale=tuple(config.mlp_lr_scale), + lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, ) From b930a391b37703e7dce23fdb544b08fe98d42084 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 21:27:40 -0400 Subject: [PATCH 003/122] stuff --- fast_llm/config.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 326845f0a..5436a2947 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -374,7 +374,6 @@ def validate[T](self: T, *, _is_validating: bool = False) -> T: else: raise type(e)("\n".join(e.args)) from None self._validated = True - print("WLIEHGIUWERGNHBWIO", self.__class__.__name__, self._explicit_fields) return self def _validate(self) -> None: @@ -713,8 +712,8 @@ def _get_class_name(cls) -> str: @classmethod def from_dict( cls, - default: typing.Union["Config", dict[str, typing.Any]], - *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], + default: "Config| dict[str, typing.Any]]", + *updates: "Config| dict[str | tuple[str, ...], typing.Any]", strict: bool = True, update_type: UpdateType = UpdateType.override, ) -> typing.Self: From 8117c47b483c26853bf5015ef85b4e94472de1b1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 21:40:37 -0400 Subject: [PATCH 004/122] fixes --- fast_llm/engine/multi_stage/config.py | 7 +++---- fast_llm/engine/multi_stage/fast_llm_model.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index d8333c9ba..342a453b4 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -10,6 +10,7 @@ Field, FieldHint, NoAutoValidate, + UpdateType, ValidationError, check_field, config_class, @@ -248,13 +249,11 @@ def get_base_model_config_class(cls) -> type[BaseModelConfig]: def from_pretrained( cls, pretrained: CheckpointLoadMetadataConfig, *updates: Config | dict[str | tuple[str, ...], typing.Any] ) -> typing.Self: - # TODO: Add *updates? - assert pretrained.path is not None - metadata = cls.load_metadata(pretrained) - return cls.from_dict(metadata.config, *updates) + return cls.from_dict(cls.load_metadata(pretrained).config, *updates, update_type=UpdateType.update) @classmethod def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": + assert config.path is not None with NoAutoValidate(): metadata = config.format.get_handler_class().load_metadata(config) try: diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 22e5ccaca..2dec7959b 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -36,11 +36,11 @@ def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] # TODO: Test with more distributed configs. # TODO: Safety checks # TODO: Handle barriers, ok file, etc. here - fast_llm_metadata = self.config_class.load_metadata(config) + metadata = self.config_class.load_metadata(config) converter = config.format.get_handler_class()(self) - converter.load(config, fast_llm_metadata) + converter.load(config, metadata) self._finalize_load(reset_optimizer=not config.optimizer_state) - return fast_llm_metadata.metadata + return metadata.metadata @classmethod def from_pretrained( From 1c995d3e76be57ec80f9f305d83b613e0c8bdba3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 21:50:00 -0400 Subject: [PATCH 005/122] fix --- fast_llm/engine/checkpoint/config.py | 16 ---------------- fast_llm/engine/checkpoint/distributed.py | 6 +++--- fast_llm/engine/checkpoint/huggingface.py | 2 +- 3 files changed, 4 insertions(+), 20 deletions(-) diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 621f7fe8d..a34725236 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -200,24 +200,8 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf @config_class() class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): - # TODO!!!!!!! _abstract = False - load_config: ModelConfigType = Field( - default=ModelConfigType.architecture, - desc="Configuration to save/load.", - hint=FieldHint.core, - ) - - def _validate(self) -> None: - super()._validate() - if self.format.enforce_architecture_match: - assert self.load_config.load_architecture - - @property - def compare_log_fn(self): - return ValueError if self.load_config.load_architecture else logger.warning - @config_class() class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index a920a52c2..953cdef88 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -67,11 +67,11 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No self._model.state_shard[:num_shards].copy_(f.get_slice("state_shard")[:num_shards]) else: log_main_rank("Checkpoint format doesn't match, using safe load") - self._model.config.base_model.compare_architecture(loaded_config.base_model, config.compare_log_fn) + self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) with SafeLoad(self._model, num_shards=num_shards, timeout=config.timeout) as context: - for rank in range(loaded_config.distributed.world_size): + for rank in range(metadata.config.distributed.world_size): loaded_model = self._model.__class__( - loaded_config.to_copy({("distributed", "rank"): rank}), + metadata.config.to_copy({("distributed", "rank"): rank}), optimizer_state_names=shard_names[1:], verbose=False, ) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 87651dc4e..d45336639 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -41,7 +41,7 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: assert not config.optimizer_state - self._model.config.base_model.compare_architecture(metadata.config.base_model, config.compare_log_fn) + self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) super().load(config, metadata) @classmethod From 506fe92917b28fc2d865edf69bad9827c5f92bfa Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 27 Mar 2025 16:04:35 -0400 Subject: [PATCH 006/122] fixes --- fast_llm/config.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 5436a2947..222a3ec71 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -266,13 +266,21 @@ def config_class(cls=None): def wrap(cls): Assert.custom(issubclass, cls, Config) - wrapped = _process_config_class(dataclasses.dataclass(cls)) + if hasattr(cls, "__post_init__"): + raise TypeError(f"`__post_init__` should not be implemented for `Config` classes") + + wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True)) wrapped_init = cls.__init__ def __init__(self, **kwargs): + # This is similar to `__post_init__`, but has access to the list of arguments passed to `__init__`. wrapped_init(self, **kwargs) self._explicit_fields = set(kwargs) + self._validated = False + self._setting_implicit_default = False + if _AUTO_VALIDATE: + self.validate() cls.__init__ = __init__ return wrapped @@ -310,17 +318,6 @@ class Config: # without them being automatically added to `_explicit_fields`. _setting_implicit_default: bool = Field(init=False, repr=False) - def __post_init__(self): - """ - Perform validation unless prevented with `NoAutoValidate`. - In general this should not be overridden in derived classes, - and all post-processing should be done in `_validate` - """ - self._validated = False - self._setting_implicit_default = False - if _AUTO_VALIDATE: - self.validate() - def __setattr__(self, key: str, value: typing.Any) -> None: """ Make the class read-only after validation. @@ -983,13 +980,15 @@ def set_nested_dict_value[ raise ValueError("Cannot update an already instantiated config.") elif isinstance(value, Config): raise ValueError("Cannot update a config dict with an already instantiated config.") - elif isinstance(d, dict): + elif isinstance(value, dict): if key in d: Assert.custom(isinstance, d[key], dict) else: d[key] = {} for key_, value_ in value.items(): set_nested_dict_value(d, key_, value_, update_type) + elif isinstance(d[key], dict): + raise ValueError("Cannot replace a dict with a non-dict value.") elif ( isinstance(value, (list, set, tuple)) and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in value) From 971d3ef23297f7dd64550facff25f8609c0fb097 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 27 Mar 2025 18:32:07 -0400 Subject: [PATCH 007/122] fixes --- fast_llm/config.py | 14 +++++++++----- fast_llm/engine/huggingface/config.py | 5 +++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 222a3ec71..7cb54919d 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -90,12 +90,12 @@ class FieldHint: class FieldVerboseLevel: - nothing = -1 + explicit = None core = 0 optional = 10 performance = 20 debug = 50 - everything = None + everything = 2**31 FieldHintDoc = { @@ -680,7 +680,7 @@ def to_copy[ ](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T: return self.from_dict(self, *updates, strict=strict) - def to_serialized(self, verbose: int | None = FieldVerboseLevel.core) -> dict[str, typing.Any]: + def to_serialized(self, verbose: int | None = FieldVerboseLevel.explicit) -> dict[str, typing.Any]: return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True) def to_logs[ @@ -863,8 +863,12 @@ def _handle_renamed_field( def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typing.Callable] = ValueError): # TODO: Check classes? - self_dict = self._to_dict(format_=_ConfigDictFormat.tuple, serializable=True) - other_dict = other._to_dict(format_=_ConfigDictFormat.tuple, serializable=True) + self_dict = self._to_dict( + format_=_ConfigDictFormat.tuple, serializable=True, verbose=FieldVerboseLevel.everything + ) + other_dict = other._to_dict( + format_=_ConfigDictFormat.tuple, serializable=True, verbose=FieldVerboseLevel.everything + ) compare = { key: (self_dict.get(key, MISSING), other_dict.get(key, MISSING)) for key in self_dict.keys() | other_dict.keys() diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index e02abc28e..2b240e4b4 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -5,6 +5,7 @@ import transformers +from fast_llm.config import FieldVerboseLevel from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig, FastLLMCheckpointFormat from fast_llm.engine.multi_stage.config import FastLLMModelConfig @@ -90,12 +91,12 @@ def __eq__(self, other) -> bool: def to_dict(self) -> dict[str, typing.Any]: out = super().to_dict() - out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=None) + out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=FieldVerboseLevel.everything) return out def to_diff_dict(self) -> dict[str, typing.Any]: out = super().to_diff_dict() - out["fast_llm_config"] = self.fast_llm_config.to_serialized() + out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=FieldVerboseLevel.explicit) return out def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True) -> None: From 6bf20cb2d72faabbf5eb6eea4de4f46180f836f8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 27 Mar 2025 21:26:59 -0400 Subject: [PATCH 008/122] Tests wip --- fast_llm/config.py | 40 ++++-- tests/config/__init__.py | 0 tests/config/common.py | 37 ++++++ tests/config/test_field.py | 176 ++++++++++++++++++++++++++ tests/data/test_dataset_from_file.py | 1 - tests/data/test_prepare_gpt_memmap.py | 1 - tests/test_config.py | 58 +++------ 7 files changed, 258 insertions(+), 55 deletions(-) create mode 100644 tests/config/__init__.py create mode 100644 tests/config/common.py create mode 100644 tests/config/test_field.py diff --git a/fast_llm/config.py b/fast_llm/config.py index 7cb54919d..67aa5b7a6 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1,4 +1,5 @@ import contextlib +import copy import dataclasses import enum import logging @@ -316,7 +317,7 @@ class Config: _explicit_fields: set[str] = Field(init=False, repr=False) # Used within `_set_implicit_default` to set implicit defaults for fields # without them being automatically added to `_explicit_fields`. - _setting_implicit_default: bool = Field(init=False, repr=False) + _setting_implicit_default: bool | None = Field(init=False, repr=False) def __setattr__(self, key: str, value: typing.Any) -> None: """ @@ -332,12 +333,20 @@ def __setattr__(self, key: str, value: typing.Any) -> None: f"Cannot set attribute `{key}`" f" in configuration class `{get_type_name(type(self))}` after validation." ) - elif not getattr(self, "_setting_implicit_default", True): - field = self.get_field(key) - if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: - # Adding to explicit field list except within `_set_implicit_default` context - # and during dataclass initialization (`_setting_implicit_default` not yet set). - self._explicit_fields.add(key) + if getattr(self, "_setting_implicit_default", None) is not None: + if self._setting_implicit_default: + if key in self._explicit_fields: + raise RuntimeError( + f"Trying to set an implicit default for field `{key}`," + f"but the field has already been set explicitly." + ) + else: + field = self.get_field(key) + if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: + # Adding to explicit field list except within `_set_implicit_default` context, + # during dataclass initialization (`_setting_implicit_default` not yet set) + # and during automated config validation (`_setting_implicit_default=None`) + self._explicit_fields.add(key) super().__setattr__(key, value) def __delattr__(self, key: str) -> None: @@ -352,8 +361,9 @@ def __delattr__(self, key: str) -> None: super().__delattr__(key) @contextlib.contextmanager - def _set_implicit_default(self): - self._setting_implicit_default = True + def _set_implicit_default(self, _value: bool | int = True): + assert self._setting_implicit_default is False + self._setting_implicit_default = _value yield self._setting_implicit_default = False @@ -383,7 +393,7 @@ def _validate(self) -> None: """ self._check_abstract() errors = [] - with self._set_implicit_default(): + with self._set_implicit_default(None): for name, field in self.fields(): if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa continue @@ -567,7 +577,7 @@ def get_field(cls, name: str) -> Field: def _to_dict( self, - verbose: int | None = None, + verbose: int | None = FieldVerboseLevel.explicit, all_fields: bool = False, format_: _ConfigDictFormat = _ConfigDictFormat.nested, serializable: bool = False, @@ -716,6 +726,8 @@ def from_dict( ) -> typing.Self: if isinstance(default, Config): default = default._to_dict() + else: + default = copy.deepcopy(default) for update in updates: if isinstance(update, Config): update = update._to_dict(format_=_ConfigDictFormat.tuple) @@ -980,7 +992,7 @@ def set_nested_dict_value[ d[key] = value elif update_type == UpdateType.update: # TODO: Improve error messages, ex. for nested cases? - if isinstance(d[key], Config): + if isinstance(d.get(key), Config): raise ValueError("Cannot update an already instantiated config.") elif isinstance(value, Config): raise ValueError("Cannot update a config dict with an already instantiated config.") @@ -991,13 +1003,13 @@ def set_nested_dict_value[ d[key] = {} for key_, value_ in value.items(): set_nested_dict_value(d, key_, value_, update_type) - elif isinstance(d[key], dict): + elif isinstance(d.get(key), dict): raise ValueError("Cannot replace a dict with a non-dict value.") elif ( isinstance(value, (list, set, tuple)) and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in value) ) or ( - isinstance(d[key], (list, set, tuple)) + isinstance(d.get(key), (list, set, tuple)) and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in d[key]) ): raise ValueError("Update not supported for nested lists.") diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/config/common.py b/tests/config/common.py new file mode 100644 index 000000000..3109175aa --- /dev/null +++ b/tests/config/common.py @@ -0,0 +1,37 @@ +import enum +import pathlib + +from fast_llm.config import Config, Field, FieldHint, config_class + + +class TestEnum(str, enum.Enum): + a = "a" + b = "b" + c = "c" + + +@config_class +class TestConfig(Config): + int_field: int = Field(default=0, hint=FieldHint.optional) + bool_field: bool = Field(default=False, hint=FieldHint.optional) + str_field: str = Field(default="", hint=FieldHint.optional) + path_field: pathlib.Path = Field(default="", hint=FieldHint.optional) + float_field: float = Field(default=4.0, hint=FieldHint.optional) + optional_field: str | None = Field(default=None, hint=FieldHint.optional) + union_field: str | int = Field(default=7, hint=FieldHint.optional) + implicit_field: str = Field(default=None, hint=FieldHint.optional) + list_field: list[int] = Field(default_factory=list, hint=FieldHint.optional) + tuple_field: tuple[int, ...] = Field(default=(), hint=FieldHint.optional) + # tuple_fixed_length_field: tuple[int, str] = Field(default=(5, "text"), hint=FieldHint.optional) + set_field: set[int] = Field(default_factory=set, hint=FieldHint.optional) + dict_field: dict[int, int] = Field(default_factory=dict, hint=FieldHint.optional) + type_field: type[int] = Field(default=int, hint=FieldHint.optional) + enum_field: TestEnum = Field(default=TestEnum.a, hint=FieldHint.optional) + core_field: int = Field(default=4, hint=FieldHint.core) + complex_field: dict[str | int, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.implicit_field is None: + self.implicit_field = "implicit" + super()._validate() diff --git a/tests/config/test_field.py b/tests/config/test_field.py new file mode 100644 index 000000000..27e7c8b54 --- /dev/null +++ b/tests/config/test_field.py @@ -0,0 +1,176 @@ +import math +import pathlib + +import numpy +import pytest + +from fast_llm.config import FieldVerboseLevel, ValidationError +from fast_llm.utils import Assert +from tests.config.common import TestConfig, TestEnum + + +def check_config(internal_config, *alternate, serialized_config=None): + serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config + for init_config in (internal_config, *alternate): + config = TestConfig.from_dict(init_config) + Assert.eq(config.to_serialized(), serialized_config) + Assert.eq(config._to_dict(), internal_config) + + +def check_invalid_config(config): + with pytest.raises(ValidationError): + TestConfig.from_dict(config) + + +def test_create_and_serialize_config(): + Assert.eq(TestConfig.from_dict({}).to_serialized(), {}) + + +@pytest.mark.parametrize("value", (0, -6, 3, True)) +def test_int_field(value): + check_config({"int_field": value}) + + +@pytest.mark.parametrize("value", (4.0, math.inf, "1", None, [4])) +def test_int_field_invalid(value): + check_invalid_config({"int_field": value}) + + +@pytest.mark.parametrize("value", (True, False)) +def test_bool_field(value): + check_config({"bool_field": value}) + + +@pytest.mark.parametrize("value", (1, "True", None, [True])) +def test_bool_field_invalid(value): + check_invalid_config({"bool_field": value}) + + +@pytest.mark.parametrize("value", ("", "text", "1", TestEnum.a)) +def test_str_field(value): + check_config({"str_field": value}) + + +@pytest.mark.parametrize("value", (1, True, None, ["text"], pathlib.Path("a"))) +def test_str_field_invalid(value): + check_invalid_config({"str_field": value}) + + +@pytest.mark.parametrize("value", (".", "text", "/a/b/c.d")) +def test_path_field(value): + check_config({"path_field": pathlib.Path(value)}, {"path_field": value}) + + +@pytest.mark.parametrize("value", (1, True, None, [pathlib.Path("a")])) +def test_path_field_invalid(value): + check_invalid_config({"path_field": value}) + + +@pytest.mark.parametrize("value", (4.0, math.pi, math.inf, 3, True, numpy.float64(3), math.nan)) +def test_float_field(value): + check_config({"float_field": value}) + + +@pytest.mark.parametrize("value", (None, [4.7], "0.0")) +def test_float_field_invalid(value): + check_invalid_config({"float_field": value}) + + +@pytest.mark.parametrize("value", ("", None, "text")) +def test_optional_field(value): + check_config({"optional_field": value}) + + +@pytest.mark.parametrize("value", (True, 6, [None])) +def test_optional_field_invalid(value): + check_invalid_config({"optional": value}) + + +@pytest.mark.parametrize("value", ("", 0, True, "text", 7)) +def test_union_field(value): + check_config({"union_field": value}) + + +@pytest.mark.parametrize("value", (6.0, [""])) +def test_union_field_invalid(value): + check_invalid_config({"optional": value}) + + +@pytest.mark.parametrize("value", ("implicit", "", "text")) +def test_implicit_field(value): + check_config({"implicit_field": value}) + + +TUPLE_VALUES = ((), (1,), (3, 4, 6), (4, 5, 4)) + + +@pytest.mark.parametrize("value", TUPLE_VALUES) +def test_list_field(value): + check_config( + {"list_field": list(value)}, + {"list_field": value}, + serialized_config={"list_field": list(value)}, + ) + + +@pytest.mark.parametrize("value", TUPLE_VALUES) +def test_tuple_field(value): + check_config( + {"tuple_field": list(value)}, + {"tuple_field": value}, + serialized_config={"tuple_field": list(value)}, + ) + + +@pytest.mark.parametrize("value", TUPLE_VALUES) +def test_set_field(value): + check_config( + {"set_field": list(set(value))}, + {"set_field": set(value)}, + {"set_field": list(value)}, + {"set_field": tuple(value)}, + serialized_config={"set_field": list(set(value))}, + ) + + +# @pytest.mark.parametrize("value", ((0, ""), (5, "text"), (True, "True"))) +# def test_tuple_fixed_length_field(value): +# expected_config = {"tuple_variable_length_field": value} +# Assert.eq(TestConfig.from_dict(expected_config).to_serialized(), expected_config) +# Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": list(value)}).to_serialized(), expected_config) +# Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": set(value)}).to_serialized(), {"tuple_variable_length_field": tuple(set(value))}) + + +@pytest.mark.parametrize("value", ({}, {True: 2}, {1: 2, 3: 4})) +def test_dict_field(value): + check_config({"dict_field": value}) + + +class IntClass(int): + pass + + +@pytest.mark.parametrize("value", (int, bool, IntClass)) +def test_type_field(value): + check_config({"type_field": value}, serialized_config={"type_field": str(value)}) + + +@pytest.mark.parametrize("value", (TestEnum.a, TestEnum.b, TestEnum.c)) +def test_enum_field(value): + check_config({"enum_field": value}, {"enum_field": value.value}) + + +def test_core_field(): + Assert.eq(TestConfig.from_dict({}).to_serialized(verbose=FieldVerboseLevel.core), {"core_field": 4}) + + +@pytest.mark.parametrize( + "value", + ( + {}, + {3: None, "text": [], False: [["", 3], ["a", -7]]}, + {0: [[".", 8]]}, + ), +) +def test_complex_field(value): + check_config({"complex_field": value}) diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py index 4ac2fcdf6..280b34137 100644 --- a/tests/data/test_dataset_from_file.py +++ b/tests/data/test_dataset_from_file.py @@ -8,5 +8,4 @@ def test_dataset_from_file(): get_test_dataset() dataset_config = {"type": "file", "path": str(DATASET_PREFIX.parent.joinpath("fast_llm_config.yaml"))} dataset = get_dataset_config(dataset_config, GPTDatasetFromFileConfig).build() - print("kjhbwiugfberibgiujebi", len(dataset)) compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 9a15a051b..a6fd3246b 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -148,7 +148,6 @@ def test_split_datasets_1(): { "training": { "type": "blended", - "name": "blended", "datasets": [ dataset_config_0.to_serialized(), { diff --git a/tests/test_config.py b/tests/test_config.py index 7141812a2..5c45db0bd 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,19 +1,14 @@ import pathlib -import pytest import subprocess import unittest.mock -import yaml +import pytest +import yaml -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerArchitectureConfig, - AddLinearBiasChoices, -) -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.config import ValidationError - +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.transformer.config import AddLinearBiasChoices, TransformerArchitectureConfig, TransformerConfig from fast_llm.models.auto import trainer_registry @@ -90,33 +85,6 @@ def test_do_use_flash_attention(): config.do_use_flash_attention(mock_distributed_config) -def test_add_linear_biases_valid_values(): - # Valid boolean values - assert TransformerArchitectureConfig(add_linear_biases=True).add_linear_biases is True - assert TransformerArchitectureConfig(add_linear_biases=False).add_linear_biases is False - - # Valid enum values - assert TransformerArchitectureConfig(add_linear_biases="nowhere").add_linear_biases == AddLinearBiasChoices.nowhere - assert ( - TransformerArchitectureConfig(add_linear_biases="everywhere").add_linear_biases - == AddLinearBiasChoices.everywhere - ) - assert ( - TransformerArchitectureConfig(add_linear_biases="only_attn_qkv").add_linear_biases == AddLinearBiasChoices.only_attn_qkv - ) - - -def test_add_linear_biases_invalid_values(): - with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases="invalid_value") - - with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases=123) - - with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases=None) - - def test_add_mlp_bias(): assert TransformerArchitectureConfig(add_linear_biases=True).add_mlp_bias is True assert TransformerArchitectureConfig(add_linear_biases=False).add_mlp_bias is False @@ -130,7 +98,9 @@ def test_add_attn_qkv_bias(): assert TransformerArchitectureConfig(add_linear_biases=False).add_attn_qkv_bias is False assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_qkv_bias is True assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_qkv_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_qkv_bias is True + assert ( + TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_qkv_bias is True + ) def test_add_attn_dense_bias(): @@ -138,4 +108,14 @@ def test_add_attn_dense_bias(): assert TransformerArchitectureConfig(add_linear_biases=False).add_attn_dense_bias is False assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_dense_bias is True assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_dense_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_dense_bias is False + assert ( + TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_dense_bias + is False + ) + + +@pytest.mark.parametrize("cls", (GPTSamplingConfig,)) +def test_serialize_default_config_updates(cls): + # Config classes used as config updates should have a default that serializes to an empty dict + # so no value is incorrectly overridden. + assert cls.from_dict({}).to_serialized() == {} From c13fb19f8763b0aebe83058b375b9732e70721d2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 28 Mar 2025 22:28:57 -0400 Subject: [PATCH 009/122] misc --- fast_llm/config.py | 29 +++++----- fast_llm/data/dataset/gpt/config.py | 4 +- fast_llm/utils.py | 2 +- tests/config/common.py | 6 +- tests/config/test_field.py | 86 ++++++++++++++++++++++------- 5 files changed, 88 insertions(+), 39 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 67aa5b7a6..c311abf4e 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -468,10 +468,10 @@ def _validate_element(cls, value, type_, name: str): elif not isinstance(type_, type): raise FieldTypeError(f"Not a type.") elif issubclass(type_, Config): - cls._validate_element_type(value, type_, name) + cls._validate_element_type(value, type_, strict=False) value.validate(_is_validating=True) else: - value = cls._validate_simple(value, type_, name) + value = cls._validate_simple(value, type_) return value @classmethod @@ -491,7 +491,7 @@ def _validate_union(cls, value, type_, name: str): @classmethod def _validate_array(cls, value, type_, name: str): origin = type_.__origin__ - cls._validate_element_type(value, (origin, list, tuple), name) + cls._validate_element_type(value, (origin, list, tuple), strict=False) args = getattr(type_, "__args__", [typing.Any, ...] if origin is tuple else [typing.Any]) errors = [] if issubclass(origin, tuple) and not (len(args) == 2 and args[1] is ...): @@ -518,7 +518,7 @@ def _validate_dict(cls, value, type_, name: str): if len(args) > 2: raise FieldTypeError(f"Invalid dict specification `{get_type_name(type_)}` for field `{name}`") args.extend([typing.Any for _ in range(2 - len(args))]) - cls._validate_element_type(value, type_.__origin__, name) + cls._validate_element_type(value, type_.__origin__, strict=False) errors = [] new_value = {} old_keys = {} @@ -534,19 +534,22 @@ def _validate_dict(cls, value, type_, name: str): return new_value @classmethod - def _validate_simple(cls, value, type_, name: str): + def _validate_simple(cls, value, type_, strict: bool = True): if hasattr(type_, "__fast_llm_validator__"): value = type_.__fast_llm_validator__(value) - elif type_ is float and isinstance(value, int): + elif type_ is float and type(value) == int: # Ints are ok too. value = float(value) elif issubclass(type_, enum.Enum) and not isinstance(value, type_) and issubclass(type_, type(value)): # Enum values are ok too. value = type_(value) - elif issubclass(type_, pathlib.PurePath) and isinstance(value, str): - # Str paths are ok too. - value = type_(value) - cls._validate_element_type(value, type_, name) + elif issubclass(type_, pathlib.PurePath): + if isinstance(value, str): + # Str paths are ok too. + value = type_(value) + # Path type may depend on the OS. + strict = False + cls._validate_element_type(value, type_, strict) return value @classmethod @@ -560,9 +563,9 @@ def _validate_type(cls, value, type_: type | tuple[type, ...], name): raise ValidationError(f"Field value `{value} is not a subclass of `{get_type_name(type_)}`") @classmethod - def _validate_element_type(cls, value, type_: type | tuple[type, ...], name): - if not isinstance(value, type_): - raise ValidationError(f"Unexpected type `{get_type_name(type(value))}`") + def _validate_element_type(cls, value, type_: type | tuple[type, ...], strict: bool = True): + if not (type(value) == type_ if strict else isinstance(value, type_)): + raise ValidationError(f"Unexpected field type: {get_type_name(type(value))} != {get_type_name(type_)}") @classmethod def fields(cls) -> typing.Iterable[tuple[str, Field]]: diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 118b3039d..4f15492a9 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -484,8 +484,8 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: "type": "slice", # TODO: this duplicates memmap datasets for each phase. "dataset": {"type": "memmap", "path": prefix}, - "begin": phase_splits[phase_index], - "end": phase_splits[phase_index + 1], + "begin": float(phase_splits[phase_index]), + "end": float(phase_splits[phase_index + 1]), } for prefix in dataset_prefixes ] diff --git a/fast_llm/utils.py b/fast_llm/utils.py index aac6f6077..4edd8b98c 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -86,7 +86,7 @@ class Assert: @staticmethod def eq(x, *args, msg=None): for arg in args: - assert x == arg, f"{x} != {arg} " + f"| {msg}" if msg else "" + assert x == arg, f"{x} != {arg} " + (f"| {msg}" if msg else "") @staticmethod def is_(x, y): diff --git a/tests/config/common.py b/tests/config/common.py index 3109175aa..143d770c3 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -4,14 +4,14 @@ from fast_llm.config import Config, Field, FieldHint, config_class -class TestEnum(str, enum.Enum): +class ExampleEnum(enum.StrEnum): a = "a" b = "b" c = "c" @config_class -class TestConfig(Config): +class ExampleConfig(Config): int_field: int = Field(default=0, hint=FieldHint.optional) bool_field: bool = Field(default=False, hint=FieldHint.optional) str_field: str = Field(default="", hint=FieldHint.optional) @@ -26,7 +26,7 @@ class TestConfig(Config): set_field: set[int] = Field(default_factory=set, hint=FieldHint.optional) dict_field: dict[int, int] = Field(default_factory=dict, hint=FieldHint.optional) type_field: type[int] = Field(default=int, hint=FieldHint.optional) - enum_field: TestEnum = Field(default=TestEnum.a, hint=FieldHint.optional) + enum_field: ExampleEnum = Field(default=ExampleEnum.a, hint=FieldHint.optional) core_field: int = Field(default=4, hint=FieldHint.core) complex_field: dict[str | int, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional) diff --git a/tests/config/test_field.py b/tests/config/test_field.py index 27e7c8b54..4f39f7414 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -6,32 +6,59 @@ from fast_llm.config import FieldVerboseLevel, ValidationError from fast_llm.utils import Assert -from tests.config.common import TestConfig, TestEnum +from tests.config.common import ExampleConfig, ExampleEnum + + +def _check_equal(config_a, config_b): + # Check for equality of both values and types. + for key in config_a.keys() | config_b.keys(): + assert key in config_a and key in config_b, key + Assert.eq(type(config_a[key]), type(config_b[key])) + if isinstance(config_a[key], (list, tuple, set)): + Assert.eq(len(config_a[key]), len(config_b[key])) + for i in range(len(config_a[key])): + _check_equal({"": config_a[key][i]}, {"": config_b[key][i]}) + elif isinstance(config_a[key], dict): + _check_equal(config_a[key], config_b[key]) + else: + try: + Assert.eq(config_a[key], config_b[key]) + except AssertionError as e: + # Special case for `math.nan` + if config_a[key] is not config_b[key]: + raise e + + +def check_equal(config_a, config_b): + try: + _check_equal(config_a, config_b) + except AssertionError as e: + raise AssertionError(config_a, config_b, *e.args) def check_config(internal_config, *alternate, serialized_config=None): serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config for init_config in (internal_config, *alternate): - config = TestConfig.from_dict(init_config) - Assert.eq(config.to_serialized(), serialized_config) - Assert.eq(config._to_dict(), internal_config) + config = ExampleConfig.from_dict(init_config) + check_equal(config.to_serialized(), serialized_config) + check_equal(config._to_dict(), internal_config) def check_invalid_config(config): with pytest.raises(ValidationError): - TestConfig.from_dict(config) + ExampleConfig.from_dict(config) def test_create_and_serialize_config(): - Assert.eq(TestConfig.from_dict({}).to_serialized(), {}) + Assert.eq(ExampleConfig.from_dict({}).to_serialized(), {}) -@pytest.mark.parametrize("value", (0, -6, 3, True)) +@pytest.mark.parametrize("value", (0, -6, 3)) def test_int_field(value): check_config({"int_field": value}) -@pytest.mark.parametrize("value", (4.0, math.inf, "1", None, [4])) +@pytest.mark.parametrize("value", (4.0, math.inf, "1", None, [4], True)) def test_int_field_invalid(value): check_invalid_config({"int_field": value}) @@ -46,12 +73,12 @@ def test_bool_field_invalid(value): check_invalid_config({"bool_field": value}) -@pytest.mark.parametrize("value", ("", "text", "1", TestEnum.a)) +@pytest.mark.parametrize("value", ("", "text", "1")) def test_str_field(value): - check_config({"str_field": value}) + check_config({"str_field": str(value)}, {"str_field": value}) -@pytest.mark.parametrize("value", (1, True, None, ["text"], pathlib.Path("a"))) +@pytest.mark.parametrize("value", (1, True, None, ["text"], pathlib.Path("a"), ExampleEnum.a)) def test_str_field_invalid(value): check_invalid_config({"str_field": value}) @@ -66,12 +93,14 @@ def test_path_field_invalid(value): check_invalid_config({"path_field": value}) -@pytest.mark.parametrize("value", (4.0, math.pi, math.inf, 3, True, numpy.float64(3), math.nan)) +@pytest.mark.parametrize("value", (4.0, math.pi, math.inf, 3, math.nan)) def test_float_field(value): - check_config({"float_field": value}) + check_config( + {"float_field": float(value)}, {"float_field": value}, serialized_config={"float_field": float(value)} + ) -@pytest.mark.parametrize("value", (None, [4.7], "0.0")) +@pytest.mark.parametrize("value", (None, [4.7], "0.0", True, numpy.float64(3))) def test_float_field_invalid(value): check_invalid_config({"float_field": value}) @@ -86,16 +115,20 @@ def test_optional_field_invalid(value): check_invalid_config({"optional": value}) -@pytest.mark.parametrize("value", ("", 0, True, "text", 7)) +@pytest.mark.parametrize("value", ("", 0, "text", 7)) def test_union_field(value): check_config({"union_field": value}) -@pytest.mark.parametrize("value", (6.0, [""])) +@pytest.mark.parametrize("value", (6.0, [""], True)) def test_union_field_invalid(value): check_invalid_config({"optional": value}) +def test_implicit_field_value(): + Assert.eq(ExampleConfig.from_dict({}).implicit_field, "implicit") + + @pytest.mark.parametrize("value", ("implicit", "", "text")) def test_implicit_field(value): check_config({"implicit_field": value}) @@ -141,11 +174,16 @@ def test_set_field(value): # Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": set(value)}).to_serialized(), {"tuple_variable_length_field": tuple(set(value))}) -@pytest.mark.parametrize("value", ({}, {True: 2}, {1: 2, 3: 4})) +@pytest.mark.parametrize("value", ({}, {1: 2, 3: 4})) def test_dict_field(value): check_config({"dict_field": value}) +@pytest.mark.parametrize("value", ({True: 2}, {4: "3"}, {4: {1: 4}}, None, 4, {1}, [5, 7], "text")) +def test_dict_field_invalid(value): + check_invalid_config({"dict_field": value}) + + class IntClass(int): pass @@ -155,22 +193,30 @@ def test_type_field(value): check_config({"type_field": value}, serialized_config={"type_field": str(value)}) -@pytest.mark.parametrize("value", (TestEnum.a, TestEnum.b, TestEnum.c)) +@pytest.mark.parametrize("value", (ExampleEnum.a, ExampleEnum.b, ExampleEnum.c)) def test_enum_field(value): check_config({"enum_field": value}, {"enum_field": value.value}) def test_core_field(): - Assert.eq(TestConfig.from_dict({}).to_serialized(verbose=FieldVerboseLevel.core), {"core_field": 4}) + Assert.eq(ExampleConfig.from_dict({}).to_serialized(verbose=FieldVerboseLevel.core), {"core_field": 4}) @pytest.mark.parametrize( "value", ( {}, - {3: None, "text": [], False: [["", 3], ["a", -7]]}, + {3: None, "text": [], 0: [["", 3], ["a", -7]]}, {0: [[".", 8]]}, ), ) def test_complex_field(value): check_config({"complex_field": value}) + + +@pytest.mark.parametrize( + "value", + ({3: None, "text": [], False: [["", 3], ["a", -7]]},), +) +def test_complex_field_invalid(value): + check_invalid_config({"complex_field": value}) From a20fcecfb870fb076bfa067b8622c6a31aa4d928 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 31 Mar 2025 20:23:42 -0400 Subject: [PATCH 010/122] tests --- tests/config/common.py | 21 ++++++ tests/config/test_field.py | 133 +++++++++++++++++++++++++++---------- 2 files changed, 120 insertions(+), 34 deletions(-) diff --git a/tests/config/common.py b/tests/config/common.py index 143d770c3..f94495073 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -35,3 +35,24 @@ def _validate(self) -> None: if self.implicit_field is None: self.implicit_field = "implicit" super()._validate() + + +@config_class +class ExampleVerboseConfig(Config): + # These fields will have non-empty default serialized values. + list_default_field: list[int] = Field(default_factory=lambda: [0], hint=FieldHint.optional) + tuple_default_field: tuple[int, ...] = Field(default=(0, 1), hint=FieldHint.optional) + tuple_fixed_length_field: tuple[int, str] = Field(default=(5, "text"), hint=FieldHint.optional) + set_default_field: set[int] = Field(default_factory=lambda: {0, 1, 2}, hint=FieldHint.optional) + dict_default_field: dict[str, int] = Field(default_factory=lambda: {"0": 0, "1": 1}, hint=FieldHint.optional) + explicit_field: str = Field(default=None, hint=FieldHint.optional) + + def _validate(self) -> None: + if self.explicit_field is None: + self.explicit_field = "explicit" + super()._validate() + + +@config_class +class ExampleNestedConfig(ExampleConfig): + nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) diff --git a/tests/config/test_field.py b/tests/config/test_field.py index 4f39f7414..bed9c181a 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -4,29 +4,29 @@ import numpy import pytest -from fast_llm.config import FieldVerboseLevel, ValidationError +from fast_llm.config import Config, FieldVerboseLevel, ValidationError from fast_llm.utils import Assert -from tests.config.common import ExampleConfig, ExampleEnum +from tests.config.common import ExampleConfig, ExampleEnum, ExampleVerboseConfig def _check_equal(config_a, config_b): # Check for equality of both values and types. - for key in config_a.keys() | config_b.keys(): - assert key in config_a and key in config_b, key - Assert.eq(type(config_a[key]), type(config_b[key])) - if isinstance(config_a[key], (list, tuple, set)): - Assert.eq(len(config_a[key]), len(config_b[key])) - for i in range(len(config_a[key])): - _check_equal({"": config_a[key][i]}, {"": config_b[key][i]}) - elif isinstance(config_a[key], dict): + Assert.eq(type(config_a), type(config_b)) + if isinstance(config_a, dict): + for key in config_a.keys() | config_b.keys(): + assert key in config_a and key in config_b, key _check_equal(config_a[key], config_b[key]) - else: - try: - Assert.eq(config_a[key], config_b[key]) - except AssertionError as e: - # Special case for `math.nan` - if config_a[key] is not config_b[key]: - raise e + elif isinstance(config_a, (list, tuple, set)): + Assert.eq(len(config_a), len(config_b)) + for i in range(len(config_a)): + _check_equal(config_a[i], config_b[i]) + else: + try: + Assert.eq(config_a, config_b) + except AssertionError: + # Special case for `math.nan` + if config_a is not config_b: + raise def check_equal(config_a, config_b): @@ -36,17 +36,30 @@ def check_equal(config_a, config_b): raise AssertionError(config_a, config_b, *e.args) -def check_config(internal_config, *alternate, serialized_config=None): +def check_config( + internal_config, + *alternate, + serialized_config=None, + cls: type[Config] = ExampleConfig, + fields: list[str] | None = None, +): serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config for init_config in (internal_config, *alternate): - config = ExampleConfig.from_dict(init_config) - check_equal(config.to_serialized(), serialized_config) - check_equal(config._to_dict(), internal_config) + config = cls.from_dict(init_config) + serialized_config_ = config.to_serialized() + internal_config_ = config._to_dict() + if fields is None: + check_equal(serialized_config_, serialized_config) + check_equal(internal_config_, internal_config) + else: + for field in fields: + check_equal(serialized_config_[field], serialized_config[field]) + check_equal(internal_config_[field], internal_config[field]) -def check_invalid_config(config): +def check_invalid_config(config, cls: type[Config] = ExampleConfig): with pytest.raises(ValidationError): - ExampleConfig.from_dict(config) + cls.from_dict(config) def test_create_and_serialize_config(): @@ -134,10 +147,11 @@ def test_implicit_field(value): check_config({"implicit_field": value}) -TUPLE_VALUES = ((), (1,), (3, 4, 6), (4, 5, 4)) +ARRAY_VALUES = ((), (1,), (3, 4, 6), (4, 5, 4)) +ARRAY_VALUES_INVALID = (6.0, {}, True, "text") -@pytest.mark.parametrize("value", TUPLE_VALUES) +@pytest.mark.parametrize("value", ARRAY_VALUES) def test_list_field(value): check_config( {"list_field": list(value)}, @@ -146,7 +160,12 @@ def test_list_field(value): ) -@pytest.mark.parametrize("value", TUPLE_VALUES) +@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) +def test_list_field_invalid(value): + check_invalid_config({"list_field": value}) + + +@pytest.mark.parametrize("value", ARRAY_VALUES) def test_tuple_field(value): check_config( {"tuple_field": list(value)}, @@ -155,7 +174,12 @@ def test_tuple_field(value): ) -@pytest.mark.parametrize("value", TUPLE_VALUES) +@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) +def test_tuple_field_invalid(value): + check_invalid_config({"tuple_field": value}) + + +@pytest.mark.parametrize("value", ARRAY_VALUES) def test_set_field(value): check_config( {"set_field": list(set(value))}, @@ -166,12 +190,9 @@ def test_set_field(value): ) -# @pytest.mark.parametrize("value", ((0, ""), (5, "text"), (True, "True"))) -# def test_tuple_fixed_length_field(value): -# expected_config = {"tuple_variable_length_field": value} -# Assert.eq(TestConfig.from_dict(expected_config).to_serialized(), expected_config) -# Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": list(value)}).to_serialized(), expected_config) -# Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": set(value)}).to_serialized(), {"tuple_variable_length_field": tuple(set(value))}) +@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) +def test_tuple_field_invalid(value): + check_invalid_config({"set_field": value}) @pytest.mark.parametrize("value", ({}, {1: 2, 3: 4})) @@ -193,9 +214,19 @@ def test_type_field(value): check_config({"type_field": value}, serialized_config={"type_field": str(value)}) +@pytest.mark.parametrize("value", (5, None, [], "text")) +def test_type_field_invalid(value): + check_invalid_config({"type_field": value}) + + @pytest.mark.parametrize("value", (ExampleEnum.a, ExampleEnum.b, ExampleEnum.c)) def test_enum_field(value): - check_config({"enum_field": value}, {"enum_field": value.value}) + check_config({"enum_field": value}, {"enum_field": str(value)}) + + +@pytest.mark.parametrize("value", (5, None, [], "text")) +def test_enum_field_invalid(value): + check_invalid_config({"type_field": value}) def test_core_field(): @@ -220,3 +251,37 @@ def test_complex_field(value): ) def test_complex_field_invalid(value): check_invalid_config({"complex_field": value}) + + +def test_verbose_config_default(): + default_values = { + "list_default_field": [0], + "tuple_default_field": [0, 1], + "tuple_fixed_length_field": [5, "text"], + "set_default_field": [0, 1, 2], + "dict_default_field": {"0": 0, "1": 1}, + "explicit_field": "explicit", + } + config = ExampleVerboseConfig.from_dict({}) + check_equal(config.to_serialized(), default_values) + check_equal(config._to_dict(), default_values) + + +@pytest.mark.parametrize("value", ((0, ""), (5, "text"), (7, "True"))) +def test_tuple_fixed_length_field(value): + check_config( + {"tuple_fixed_length_field": list(value)}, + {"tuple_fixed_length_field": value}, + serialized_config={"tuple_fixed_length_field": list(value)}, + cls=ExampleVerboseConfig, + fields=["tuple_fixed_length_field"], + ) + + +@pytest.mark.parametrize("value", ((), (5,), ("", 0), ("0", "True"), (0, "", "text"))) +def test_tuple_fixed_length_field_invalid(value): + check_invalid_config({"tuple_fixed_length_field": value}, cls=ExampleVerboseConfig) + + +# TODO: Test other fields with defaults. +# TODO: Test nested fields. From 9af372df69e71e7a818bb52f4e7d26706d42e19c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 18:07:02 -0400 Subject: [PATCH 011/122] Tests, fixes, remove tuple format --- fast_llm/config.py | 111 ++++++------------ fast_llm/data/dataset/gpt/config.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 2 +- fast_llm/engine/checkpoint/distributed.py | 8 +- fast_llm/engine/checkpoint/external.py | 2 +- fast_llm/engine/checkpoint/huggingface.py | 2 +- fast_llm/engine/checkpoint/state_dict.py | 2 +- fast_llm/engine/config_utils/run.py | 4 +- fast_llm/engine/huggingface/config.py | 4 +- fast_llm/engine/training/wandb.py | 2 +- fast_llm/utils.py | 32 +++++ tests/config/common.py | 31 ++++- tests/config/test_field.py | 67 ++--------- tests/config/test_update.py | 52 ++++++++ tests/data/test_prepare_gpt_memmap.py | 20 ++-- tests/test_config.py | 2 +- tools/moe_add_experts.py | 2 +- 18 files changed, 185 insertions(+), 162 deletions(-) create mode 100644 tests/config/test_update.py diff --git a/fast_llm/config.py b/fast_llm/config.py index c311abf4e..0abd90737 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -11,7 +11,7 @@ import yaml -from fast_llm.utils import Assert, Tag, get_type_name, header, log +from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log logger = logging.getLogger(__name__) @@ -38,13 +38,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): _AUTO_VALIDATE = self._old_value -class _ConfigDictFormat(str, enum.Enum): - # TODO v0.3: delete class - flat = "flat" - nested = "nested" - tuple = "tuple" - - class UpdateType(str, enum.Enum): # Override entries no matter what they contais. override = "override" @@ -578,33 +571,26 @@ def fields(cls) -> typing.Iterable[tuple[str, Field]]: def get_field(cls, name: str) -> Field: return cls.__dataclass_fields__[name] # noqa - def _to_dict( + def to_dict( self, verbose: int | None = FieldVerboseLevel.explicit, all_fields: bool = False, - format_: _ConfigDictFormat = _ConfigDictFormat.nested, - serializable: bool = False, + serialized: bool = True, ) -> dict[str, typing.Any]: """ Serialize the config to a dict that can (generally) be used to reconstruct an identical `Config`. - When not flat, the dict includes a `__class__` entry which allows support for derived classes. Args: all_fields: Include the derived fields, with `init=False`. - format_: The config format used to represent nested configs. Options: - * `ConfigDictFormat.nested`: Preserve the nested config structure by returning nested dicts. - Also save a `__class__` entry to support derived classes. Standard format. - * `ConfigDictFormat.tuple`: Preserve the nested config structure by returning tuples of keys. - Used for config updates. - serializable: Ensure the dict is serializable to json or yaml. Information may be lost. + serialized: Ensure the dict is serializable to json or yaml. Information may be lost. """ arg_dict = {} for name, field in self.fields(): value = getattr(self, name, MISSING) - self._add_field_to_args(arg_dict, name, field, value, verbose, all_fields, format_, serializable) + self._add_field_to_args(arg_dict, name, field, value, verbose, all_fields, serialized) if hasattr(self, "_unknown_fields"): for name, value in self._unknown_fields.items(): - self._add_field_to_args(arg_dict, f"!!! {name}", None, value, None, all_fields, format_, serializable) + self._add_field_to_args(arg_dict, f"!!! {name}", None, value, None, all_fields, serialized) return arg_dict @@ -616,13 +602,12 @@ def _add_field_to_args( value: typing.Any, verbose: int | None = None, all_fields: bool = False, - format_: _ConfigDictFormat = _ConfigDictFormat.nested, - serializable: bool = False, + serializable: bool = True, ) -> None: if ( field is not None and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR) - and not (all_fields) + and not all_fields ): # Exclude class variables and derived fields unless requested explicitly. return @@ -632,48 +617,36 @@ def _add_field_to_args( or (verbose is not None and verbose >= FieldHintImportance[field.hint]) ) if isinstance(value, Config): - field_value = value._to_dict( + field_value = value.to_dict( verbose=verbose, all_fields=all_fields, - format_=format_, - serializable=serializable, + serialized=serializable, ) # Empty configs can safely be trimmed. explicit_field = all_fields elif isinstance(value, (list, tuple, set)): - field_value = {} if format_ == _ConfigDictFormat.tuple else [] + field_value = [] for i, list_value in enumerate(value): - self._add_field_to_args( - field_value, str(i), None, list_value, verbose, all_fields, format_, serializable - ) + self._add_field_to_args(field_value, str(i), None, list_value, verbose, all_fields, serializable) elif isinstance(value, dict): field_value = {} for dict_name, dict_value in value.items(): - self._add_field_to_args( - field_value, dict_name, None, dict_value, verbose, all_fields, format_, serializable - ) + self._add_field_to_args(field_value, dict_name, None, dict_value, verbose, all_fields, serializable) elif explicit_field: field_value = value if serializable: field_value = self._serialize_value(value) - if format_ == _ConfigDictFormat.tuple: - field_value = {(): field_value} else: # Exclude unimportant (implicit or explicit) default values. return if serializable: name = self._serialize_value(name) - if format_ == _ConfigDictFormat.tuple: - args.update({(name,) + name_: value_ for name_, value_ in field_value.items()}) - elif format_ == _ConfigDictFormat.nested: - if not isinstance(field_value, (dict, list)) or len(field_value) > 0 or explicit_field or all_fields: - if isinstance(args, dict): - args[name] = field_value - else: - args.append(field_value) - else: - raise NotImplementedError(format_) + if not isinstance(field_value, (dict, list)) or len(field_value) > 0 or explicit_field or all_fields: + if isinstance(args, dict): + args[name] = field_value + else: + args.append(field_value) @classmethod def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None: @@ -689,12 +662,14 @@ def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None: return value def to_copy[ - T - ](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T: - return self.from_dict(self, *updates, strict=strict) - - def to_serialized(self, verbose: int | None = FieldVerboseLevel.explicit) -> dict[str, typing.Any]: - return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True) + T: Config, + ]( + self: T, + *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], + strict: bool = True, + update_type: UpdateType = UpdateType.override, + ) -> T: + return self.from_dict(self, *updates, strict=strict, update_type=update_type) def to_logs[ T @@ -706,7 +681,7 @@ def to_logs[ width: int = 80, fill_char: str = "-", ) -> T: - arg_dict = self.to_serialized(verbose=verbose) + arg_dict = self.to_dict(verbose=verbose) if title is None: title = self._get_class_name() return log_fn( @@ -728,12 +703,14 @@ def from_dict( update_type: UpdateType = UpdateType.override, ) -> typing.Self: if isinstance(default, Config): - default = default._to_dict() + default = default.to_dict(serialized=False) else: default = copy.deepcopy(default) for update in updates: if isinstance(update, Config): - update = update._to_dict(format_=_ConfigDictFormat.tuple) + update = update.to_dict(serialized=False) + else: + update = copy.deepcopy(update) for keys, value in update.items(): set_nested_dict_value(default, keys, value, update_type) @@ -878,27 +855,15 @@ def _handle_renamed_field( def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typing.Callable] = ValueError): # TODO: Check classes? - self_dict = self._to_dict( - format_=_ConfigDictFormat.tuple, serializable=True, verbose=FieldVerboseLevel.everything - ) - other_dict = other._to_dict( - format_=_ConfigDictFormat.tuple, serializable=True, verbose=FieldVerboseLevel.everything - ) - compare = { - key: (self_dict.get(key, MISSING), other_dict.get(key, MISSING)) - for key in self_dict.keys() | other_dict.keys() - } - diff = { - key: (self_value, other_value) - for key, (self_value, other_value) in compare.items() - if self_value != other_value - } - if diff: - log( + self_dict = self.to_dict(verbose=FieldVerboseLevel.everything) + other_dict = other.to_dict(verbose=FieldVerboseLevel.everything) + errors = compare_nested(self_dict, other_dict) + if errors: + return log( f"Config diff:\n " + "\n ".join( f"{'.'.join(key)}`: `{self_value}` != `{other_value}`" - for key, (self_value, other_value) in diff.items() + for key, (self_value, other_value) in errors.items() ), log_fn=log_fn, ) @@ -1005,7 +970,7 @@ def set_nested_dict_value[ else: d[key] = {} for key_, value_ in value.items(): - set_nested_dict_value(d, key_, value_, update_type) + set_nested_dict_value(d[key], key_, value_, update_type) elif isinstance(d.get(key), dict): raise ValueError("Cannot replace a dict with a non-dict value.") elif ( diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 9c5e6f13c..0958f1185 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -505,7 +505,7 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: dataset_config = { "type": "fim", "dataset": dataset_config, - **self.fim.to_serialized(), + **self.fim.to_dict(), } # Legacy sampling config dataset_config = { diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f5d230312..25529ef08 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -154,7 +154,7 @@ def _sample(self) -> None: "num_samples": self._num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._sequence_length, - "config": self._config.to_serialized(), + "config": self._config.to_dict(), } self._load_yaml_data(yaml_data) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b3dae1df1..23e497bf8 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -281,7 +281,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_path: pathlib.Path) -> None: logger.info(f"Saving config to {output_path}") yaml.safe_dump( - dataset_config.to_serialized(), + dataset_config.to_dict(), output_path.open("w"), ) diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 503839f0d..f27fff5dd 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -32,7 +32,7 @@ def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetada return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r"))) def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: - serialized_metadata = metadata.to_serialized() + serialized_metadata = metadata.to_dict() if self._model.config.distributed.rank == 0: yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) safetensors.torch.save_file( @@ -50,10 +50,8 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No Assert.leq(set(self.get_shard_names(config)), set(metadata.shards)) Assert.eq(metadata.shards[: len(shard_names)], list(shard_names)) - same_format = ( - loaded_config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None) - and config.optimizer_state - ) + # Using `log_fn=bool` sets the output to true if the error list is non-empty. + same_format = config.optimizer_state and not loaded_config.compare(self._model.config, log_fn=bool) # Make sure all nodes agree on which loading scheme to use. # Note: they may not agree before the broadcast because of the rank comparison, but that's ok. same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 98cab927e..654ba21fd 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -232,7 +232,7 @@ def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetada fast_llm_version=__version__, model=cls._model_class, format=config.format, - config=cls._model_class.from_dict({"base_model": imported_model_config.to_serialized()}), + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), shards=["weights"], ) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 87651dc4e..f335015a6 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -34,7 +34,7 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch huggingface_config = self._export_config(self._model.config.base_model) self._save_config(config.path, huggingface_config) return { - "fast_llm_metadata": metadata.to_serialized(), + "fast_llm_metadata": metadata.to_dict(), "model_config": huggingface_config, "format": "pt", } diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 5288d49f6..71c83ece3 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -71,7 +71,7 @@ def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metada def _serialize_metadata( self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata ) -> dict[str, typing.Any]: - return metadata.to_serialized() + return metadata.to_dict() def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: with SafeLoad(self._model, shard_names=self.get_shard_names(config), timeout=config.timeout) as context: diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 0ac463396..d63774092 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -147,8 +147,8 @@ def __init__( self._is_pipeline_parallel_main_rank = ( self._distributed_config.data_rank == 0 and self._distributed_config.tensor_rank == 0 ) - config_dict = config.to_serialized() - config_dict_verbose = config.to_serialized(verbose=FieldVerboseLevel.performance) + config_dict = config.to_dict() + config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance) if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index 2b240e4b4..d4b46bcc0 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -91,12 +91,12 @@ def __eq__(self, other) -> bool: def to_dict(self) -> dict[str, typing.Any]: out = super().to_dict() - out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=FieldVerboseLevel.everything) + out["fast_llm_config"] = self.fast_llm_config.to_dict(verbose=FieldVerboseLevel.everything) return out def to_diff_dict(self) -> dict[str, typing.Any]: out = super().to_diff_dict() - out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=FieldVerboseLevel.explicit) + out["fast_llm_config"] = self.fast_llm_config.to_dict(verbose=FieldVerboseLevel.explicit) return out def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True) -> None: diff --git a/fast_llm/engine/training/wandb.py b/fast_llm/engine/training/wandb.py index e3d421a30..185b89c28 100644 --- a/fast_llm/engine/training/wandb.py +++ b/fast_llm/engine/training/wandb.py @@ -40,7 +40,7 @@ def __init__(self, config: WandbConfig, run: Run, experiment_config: Config): if wandb_path is not None: yaml.safe_dump(wandb_config, wandb_path.open("w")) # TODO: Does wandb work with nested configs? - self._wandb = wandb.init(config=experiment_config.to_serialized(), **wandb_config) + self._wandb = wandb.init(config=experiment_config.to_dict(), **wandb_config) else: self._wandb = None diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 4edd8b98c..da083eef2 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -289,3 +289,35 @@ def new_decorator(*args, **kwargs): return out return new_decorator + + +def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple = ()): + if errors is None: + errors = [] + # Check for equality of both values and types. + if type(config_a) != type(config_b): + errors.append(f"Type mismatch for key `{".".join(prefix)}`: {type(config_a)} != {type(config_b)}") + if isinstance(config_a, dict): + for key in config_a.keys() | config_b.keys(): + key_ = prefix + (key,) + if key not in config_a: + errors.append(f"Key `{".".join(key_)}` missing in lhs.") + elif key not in config_b: + errors.append(f"Key `{".".join(key_)}` missing in rhs.") + else: + compare_nested(config_a[key], config_b[key], errors, key_) + elif isinstance(config_a, (list, tuple, set)): + if len(config_a) != len(config_b): + errors.append(f"Length mismatch for key `{".".join(prefix)}`: {len(config_a)} != {len(config_b)}.") + else: + for i in range(len(config_a)): + compare_nested(config_a[i], config_b[i], errors, prefix + (str(i),)) + elif config_a != config_b and config_a is not config_b: + # `is not` needed for special cases like `math.nan` + errors.append(f"Different value for key `{".".join(prefix)}`: {config_a} != {config_b}.") + return errors + + +def check_equal_nested(config_a, config_b): + if errors := compare_nested(config_a, config_b): + raise ValueError("\n".join(errors)) diff --git a/tests/config/common.py b/tests/config/common.py index f94495073..a26579262 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -1,7 +1,10 @@ import enum import pathlib -from fast_llm.config import Config, Field, FieldHint, config_class +import pytest + +from fast_llm.config import Config, Field, FieldHint, ValidationError, config_class +from fast_llm.utils import check_equal_nested class ExampleEnum(enum.StrEnum): @@ -56,3 +59,29 @@ def _validate(self) -> None: @config_class class ExampleNestedConfig(ExampleConfig): nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) + + +def check_config( + internal_config, + *alternate, + serialized_config=None, + cls: type[Config] = ExampleConfig, + fields: list[str] | None = None, +): + serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config + for init_config in (internal_config, *alternate): + config = cls.from_dict(init_config) + serialized_config_ = config.to_dict() + internal_config_ = config.to_dict(serialized=False) + if fields is None: + check_equal_nested(serialized_config_, serialized_config) + check_equal_nested(internal_config_, internal_config) + else: + for field in fields: + check_equal_nested(serialized_config_[field], serialized_config[field]) + check_equal_nested(internal_config_[field], internal_config[field]) + + +def check_invalid_config(config, cls: type[Config] = ExampleConfig): + with pytest.raises(ValidationError): + cls.from_dict(config) diff --git a/tests/config/test_field.py b/tests/config/test_field.py index bed9c181a..91b5c0d82 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -4,66 +4,13 @@ import numpy import pytest -from fast_llm.config import Config, FieldVerboseLevel, ValidationError -from fast_llm.utils import Assert -from tests.config.common import ExampleConfig, ExampleEnum, ExampleVerboseConfig - - -def _check_equal(config_a, config_b): - # Check for equality of both values and types. - Assert.eq(type(config_a), type(config_b)) - if isinstance(config_a, dict): - for key in config_a.keys() | config_b.keys(): - assert key in config_a and key in config_b, key - _check_equal(config_a[key], config_b[key]) - elif isinstance(config_a, (list, tuple, set)): - Assert.eq(len(config_a), len(config_b)) - for i in range(len(config_a)): - _check_equal(config_a[i], config_b[i]) - else: - try: - Assert.eq(config_a, config_b) - except AssertionError: - # Special case for `math.nan` - if config_a is not config_b: - raise - - -def check_equal(config_a, config_b): - try: - _check_equal(config_a, config_b) - except AssertionError as e: - raise AssertionError(config_a, config_b, *e.args) - - -def check_config( - internal_config, - *alternate, - serialized_config=None, - cls: type[Config] = ExampleConfig, - fields: list[str] | None = None, -): - serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config - for init_config in (internal_config, *alternate): - config = cls.from_dict(init_config) - serialized_config_ = config.to_serialized() - internal_config_ = config._to_dict() - if fields is None: - check_equal(serialized_config_, serialized_config) - check_equal(internal_config_, internal_config) - else: - for field in fields: - check_equal(serialized_config_[field], serialized_config[field]) - check_equal(internal_config_[field], internal_config[field]) - - -def check_invalid_config(config, cls: type[Config] = ExampleConfig): - with pytest.raises(ValidationError): - cls.from_dict(config) +from fast_llm.config import FieldVerboseLevel +from fast_llm.utils import Assert, check_equal_nested +from tests.config.common import ExampleConfig, ExampleEnum, ExampleVerboseConfig, check_config, check_invalid_config def test_create_and_serialize_config(): - Assert.eq(ExampleConfig.from_dict({}).to_serialized(), {}) + Assert.eq(ExampleConfig.from_dict({}).to_dict(), {}) @pytest.mark.parametrize("value", (0, -6, 3)) @@ -230,7 +177,7 @@ def test_enum_field_invalid(value): def test_core_field(): - Assert.eq(ExampleConfig.from_dict({}).to_serialized(verbose=FieldVerboseLevel.core), {"core_field": 4}) + Assert.eq(ExampleConfig.from_dict({}).to_dict(verbose=FieldVerboseLevel.core), {"core_field": 4}) @pytest.mark.parametrize( @@ -263,8 +210,8 @@ def test_verbose_config_default(): "explicit_field": "explicit", } config = ExampleVerboseConfig.from_dict({}) - check_equal(config.to_serialized(), default_values) - check_equal(config._to_dict(), default_values) + check_equal_nested(config.to_dict(), default_values) + check_equal_nested(config.to_dict(serialized=False), default_values) @pytest.mark.parametrize("value", ((0, ""), (5, "text"), (7, "True"))) diff --git a/tests/config/test_update.py b/tests/config/test_update.py new file mode 100644 index 000000000..ad534d49e --- /dev/null +++ b/tests/config/test_update.py @@ -0,0 +1,52 @@ +import pytest + +from fast_llm.config import UpdateType +from fast_llm.utils import check_equal_nested +from tests.config.common import ExampleNestedConfig + +TEST_CONFIGS = ( + ( + # Empty config + {}, + {}, + {}, + None, + ), + ( + # Update unset field; don't update set field; update + {"int_field": 4, "str_field": "text"}, + {"float_field": 3.0, "str_field": ""}, + {"int_field": 4, "float_field": 3.0, "str_field": ""}, + None, + ), + ( + # Update/override nested field. + {"nested_field": {"int_field": 4, "str_field": "text"}}, + {"nested_field": {"float_field": 3.0, "str_field": ""}}, + {"nested_field": {"int_field": 4, "float_field": 3.0, "str_field": ""}}, + {"nested_field": {"float_field": 3.0, "str_field": ""}}, + ), + # TODO: Add more complex cases +) + + +@pytest.mark.parametrize(("base", "update", "updated", "overridden"), TEST_CONFIGS) +def test_update(base, update, updated, overridden) -> None: + if overridden is None: + overridden = updated + check_equal_nested(ExampleNestedConfig.from_dict(base, update, update_type=UpdateType.update).to_dict(), updated) + check_equal_nested( + ExampleNestedConfig.from_dict(base) + .to_copy(ExampleNestedConfig.from_dict(update), update_type=UpdateType.update) + .to_dict(), + updated, + ) + check_equal_nested( + ExampleNestedConfig.from_dict(base, update, update_type=UpdateType.override).to_dict(), overridden + ) + check_equal_nested( + ExampleNestedConfig.from_dict(base) + .to_copy(ExampleNestedConfig.from_dict(update), update_type=UpdateType.override) + .to_dict(), + overridden, + ) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index a6fd3246b..9dd7975c2 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -95,20 +95,20 @@ def test_split_dataset(): {"training": 3, "validation": 1}, pathlib.Path("."), ) - config = {key: value.to_serialized() for key, value in config.items()} + config = {key: value.to_dict() for key, value in config.items()} Assert.eq( config, { "training": { "type": "slice", - "dataset": dataset_config_0.to_serialized(), + "dataset": dataset_config_0.to_dict(), "begin": 0, "end": 0.75, }, "validation": { "type": "slice", - "dataset": dataset_config_0.to_serialized(), + "dataset": dataset_config_0.to_dict(), "begin": 0.75, "end": 1, }, @@ -124,13 +124,13 @@ def test_split_datasets_0(): {"training": 1, "validation": 1}, pathlib.Path("."), ) - config = {key: value.to_serialized() for key, value in config.items()} + config = {key: value.to_dict() for key, value in config.items()} Assert.eq( config, { - "training": dataset_config_0.to_serialized(), - "validation": dataset_config_1.to_serialized(), + "training": dataset_config_0.to_dict(), + "validation": dataset_config_1.to_dict(), }, ) @@ -141,7 +141,7 @@ def test_split_datasets_1(): config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".") ) - config = {key: value.to_serialized() for key, value in config.items()} + config = {key: value.to_dict() for key, value in config.items()} Assert.eq( config, @@ -149,10 +149,10 @@ def test_split_datasets_1(): "training": { "type": "blended", "datasets": [ - dataset_config_0.to_serialized(), + dataset_config_0.to_dict(), { "type": "slice", - "dataset": dataset_config_1.to_serialized(), + "dataset": dataset_config_1.to_dict(), "begin": 0, "end": 0.5, }, @@ -161,7 +161,7 @@ def test_split_datasets_1(): }, "validation": { "type": "slice", - "dataset": dataset_config_1.to_serialized(), + "dataset": dataset_config_1.to_dict(), "begin": 0.5, "end": 1, }, diff --git a/tests/test_config.py b/tests/test_config.py index 5c45db0bd..79437e9d6 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -118,4 +118,4 @@ def test_add_attn_dense_bias(): def test_serialize_default_config_updates(cls): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. - assert cls.from_dict({}).to_serialized() == {} + assert cls.from_dict({}).to_dict() == {} diff --git a/tools/moe_add_experts.py b/tools/moe_add_experts.py index 975ece86f..69311017f 100644 --- a/tools/moe_add_experts.py +++ b/tools/moe_add_experts.py @@ -93,7 +93,7 @@ def run(self): model.save_pretrained(self.output_dir, state_dict=state_dict) # Save surgery config as yaml - yaml.safe_dump(self.to_serialized(), (self.output_dir / "surgery_config.yaml").open("w")) + yaml.safe_dump(self.to_dict(), (self.output_dir / "surgery_config.yaml").open("w")) logger.info("Done!") From dded00af39930f7cc57ade985dd65e314e3b62a4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 20:19:15 -0400 Subject: [PATCH 012/122] fix --- fast_llm/config.py | 10 ++++------ fast_llm/utils.py | 5 +++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 0abd90737..62db786dd 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -391,9 +391,11 @@ def _validate(self) -> None: if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa continue value = getattr(self, name) - if value is DEFAULT: + if isinstance(value, Tag): + Assert.is_(value, DEFAULT) # Replace the value with its default. # We still need to validate because some fields have invalid defaults. + # TODO: Improve (still needed with new config update format? Do earlier to allow implicit defaults?) value = field.default new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False) setattr(self, name, new_value) @@ -860,11 +862,7 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ errors = compare_nested(self_dict, other_dict) if errors: return log( - f"Config diff:\n " - + "\n ".join( - f"{'.'.join(key)}`: `{self_value}` != `{other_value}`" - for key, (self_value, other_value) in errors.items() - ), + f"Config comparison errors:\n " + "\n".join(errors), log_fn=log_fn, ) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index da083eef2..a8c5eac61 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -71,12 +71,17 @@ def rms_diff(x: "torch.Tensor", y: "torch.Tensor") -> "torch.Tensor": class Tag: + __slots__ = ("value",) + def __init__(self, value: str): self.value = value def __repr__(self) -> str: return self.value + def __deepcopy__(self, memodict: dict[str, typing.Any]) -> typing.Self: + return self + class Assert: """ From 986f9f3c9a5ebdc40dd9879540449a0fdb2aa80f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 20:27:32 -0400 Subject: [PATCH 013/122] fix --- tests/test_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index d5685a719..d446f4142 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -409,6 +409,7 @@ def test_load_pretrained_distributed_with_config(): ) +@pytest.mark.skip(reason="Fails because of incorrect init config.") @pytest.mark.depends(on=["test_load_pretrained_distributed_in_dp2"]) def test_load_pretrained_in_dp2_match_checkpoint(): test_ckpt_path = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" From 8e3e7957b759d17c194d78edf736af7136d0586d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 21:21:25 -0400 Subject: [PATCH 014/122] fixes --- fast_llm/engine/checkpoint/distributed.py | 2 +- tests/common.py | 4 ++-- tests/test_checkpoint.py | 3 --- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index e3cd7d162..4225a4045 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -48,7 +48,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No Assert.eq(metadata.shards[: len(shard_names)], list(shard_names)) # Using `log_fn=bool` sets the output to true if the error list is non-empty. - same_format = config.optimizer_state and not loaded_config.compare(self._model.config, log_fn=bool) + same_format = config.optimizer_state and not metadata.config.compare(self._model.config, log_fn=bool) # Make sure all nodes agree on which loading scheme to use. # Note: they may not agree before the broadcast because of the rank comparison, but that's ok. same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group) diff --git a/tests/common.py b/tests/common.py index 14ec5c61b..cc7499019 100644 --- a/tests/common.py +++ b/tests/common.py @@ -54,7 +54,7 @@ "model.base_model.transformer.num_layers=2", "model.base_model.transformer.hidden_size=256", "model.base_model.transformer.num_attention_heads=8", - "model.base_model.transformer.init_method_std=0.022", + # "model.base_model.transformer.init_method_std=0.022", f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", @@ -101,7 +101,7 @@ "--global-batch-size=8", "--max-position-embeddings=512", "--seq-length=512", - "--init-method-std=0.022", + "--init-method-std=0.0625", "--lr=0.0001", "--num-workers=0", "--valid-num-workers=0", diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index d446f4142..6793a6700 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -259,7 +259,6 @@ def test_load_pretrained_distributed_checkpoint(): path=_CKPT_PATH, format=DistributedCheckpointFormat, optimizer_state=True, - load_config=ModelConfigType.fast_llm, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) _compare_configs(config.base_model, model.config.base_model) @@ -409,7 +408,6 @@ def test_load_pretrained_distributed_with_config(): ) -@pytest.mark.skip(reason="Fails because of incorrect init config.") @pytest.mark.depends(on=["test_load_pretrained_distributed_in_dp2"]) def test_load_pretrained_in_dp2_match_checkpoint(): test_ckpt_path = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" @@ -454,7 +452,6 @@ def test_load_pretrained_in_dp2_match_checkpoint(): assert (stage_shard_test[stage_shard_ref.numel() :] == 0).all() # noqa -@pytest.mark.skip(reason="Fails because of incorrect init config.") @pytest.mark.slow @pytest.mark.depends(on=["test_load_pretrained_in_dp2_match_checkpoint"]) def test_load_distributed_checkpoint_dp2(): From da6eb7bf7b16b709c81f06df50a5cac342ee7915 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 01:17:37 -0400 Subject: [PATCH 015/122] fixes --- fast_llm/data/dataset/gpt/sampled.py | 4 +- fast_llm/engine/checkpoint/config.py | 30 +++++- fast_llm/engine/checkpoint/distributed.py | 24 +++-- fast_llm/engine/checkpoint/external.py | 2 +- fast_llm/engine/checkpoint/huggingface.py | 5 +- fast_llm/engine/checkpoint/state_dict.py | 4 +- fast_llm/engine/huggingface/config.py | 5 +- fast_llm/engine/multi_stage/fast_llm_model.py | 7 +- fast_llm/engine/training/trainer.py | 1 + tests/common.py | 6 +- tests/test_checkpoint.py | 95 +++++++++++-------- tests/test_config.py | 11 ++- 12 files changed, 124 insertions(+), 70 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index c96eb35f6..fa4862161 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -65,7 +65,7 @@ def __getitem__(self, item: typing.Any) -> np.ndarray: def _lazy_load(self): if self._array is None: - assert self.exists() + assert self.exists(), self._path self._array = np.load(self._path, mmap_mode="r") @@ -432,7 +432,7 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if unshuffled_tokens := data.get("unshuffled_tokens") is not None: + if (unshuffled_tokens := data.get("unshuffled_tokens")) is not None: self._unshuffled_tokens = unshuffled_tokens else: self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"] diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 7dbd5ce73..55440a5cc 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -202,6 +202,17 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): _abstract = False + load_config: ModelConfigType = Field( + default=ModelConfigType.model, + desc="Configuration to save/load.", + hint=FieldHint.core, + ) + + def _validate(self) -> None: + super()._validate() + if self.format.enforce_architecture_match: + assert self.load_config.load_architecture + @config_class() class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): @@ -225,8 +236,23 @@ def __init__(self, model: "FastLLMModel"): # TODO: save_metadata? @classmethod - @abc.abstractmethod def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": + updates = {} + metadata = cls._load_metadata(config) + if not config.load_config.load_fast_llm: + updates[("config", "multi_stage")] = {} + updates[("config", "distributed")] = {} + if not config.load_config.load_architecture: + updates[("config", "base_model")] = {} + elif not config.load_config.load_base_model: + updates[("config", "base_model")] = metadata.config.base_model.get_architecture() + if updates: + metadata = metadata.to_copy(updates) + return metadata + + @classmethod + @abc.abstractmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": pass @abc.abstractmethod @@ -234,7 +260,7 @@ def save(self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata"): pass @abc.abstractmethod - def load(self, config: CheckpointLoadConfig, metadata: "CheckpointMetadata"): + def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: pass def get_shard_names(self, config: CheckpointStateConfigBase) -> tuple[str, ...]: diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 4225a4045..ac06df5c4 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -13,6 +13,7 @@ CheckpointLoadMetadataConfig, CheckpointSaveConfig, DistributedCheckpointFormat, + ModelConfigType, export_safetensors_metadata, ) from fast_llm.engine.checkpoint.safe_load import SafeLoad @@ -27,7 +28,7 @@ class DistributedCheckpointHandler(CheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = DistributedCheckpointFormat @classmethod - def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r"))) def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: @@ -40,15 +41,16 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No metadata=export_safetensors_metadata(serialized_metadata), ) - def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: + def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: # TODO: More safety checks + loaded_metadata = self._model.config.load_metadata(config.to_copy({"load_config": ModelConfigType.fast_llm})) shard_names = self.get_shard_names(config) # Make sure all shards to load are in the checkpoint. - Assert.leq(set(self.get_shard_names(config)), set(metadata.shards)) - Assert.eq(metadata.shards[: len(shard_names)], list(shard_names)) + Assert.leq(set(self.get_shard_names(config)), set(loaded_metadata.shards)) + Assert.eq(loaded_metadata.shards[: len(shard_names)], list(shard_names)) # Using `log_fn=bool` sets the output to true if the error list is non-empty. - same_format = config.optimizer_state and not metadata.config.compare(self._model.config, log_fn=bool) + same_format = config.optimizer_state and not loaded_metadata.config.compare(self._model.config, log_fn=bool) # Make sure all nodes agree on which loading scheme to use. # Note: they may not agree before the broadcast because of the rank comparison, but that's ok. same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group) @@ -67,7 +69,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning) for shard_name in shard_names: self._model.get_shard(shard_name).copy_( - f.get_slice("state_shard")[metadata.shards.index(shard_name)] + f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)] ) else: # TODO: Does this copy twice? @@ -76,11 +78,11 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No else: log_main_rank("Checkpoint format doesn't match, using safe load", log_fn=logger.info) - self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) + self._model.config.base_model.compare_architecture(loaded_metadata.config.base_model, logger.warning) with SafeLoad(self._model, shard_names=shard_names, timeout=config.timeout) as context: - for rank in range(metadata.config.distributed.world_size): + for rank in range(loaded_metadata.config.distributed.world_size): loaded_model = self._model.__class__( - metadata.config.to_copy({("distributed", "rank"): rank}), + loaded_metadata.config.to_copy({("distributed", "rank"): rank}), optimizer_state_names=shard_names[1:], verbose=False, ) @@ -94,7 +96,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No # TODO v0.3: Use checkpoint version? Drop support? log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning) loaded_shards = { - shard_name: f.get_slice("state_shard")[metadata.shards.index(shard_name)] + shard_name: f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)] for shard_name in shard_names } else: @@ -119,3 +121,5 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No ) context.mark_as_loaded(counter.item()) + + return loaded_metadata.metadata diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 654ba21fd..e3b6dcf25 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -226,7 +226,7 @@ def __init__(self, model: "FastLLMModel"): } @classmethod - def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: imported_model_config = cls._import_config(cls._load_config(config.path), True) return CheckpointMetadata( fast_llm_version=__version__, diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 7357b722a..a5777d45f 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -39,10 +39,11 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch "format": "pt", } - def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: + def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: assert not config.optimizer_state + metadata = self._model.config.load_metadata(config) self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) - super().load(config, metadata) + super().load(config) @classmethod def get_huggingface_model_type(self) -> str: diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 71c83ece3..1bb47e5c3 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -73,7 +73,7 @@ def _serialize_metadata( ) -> dict[str, typing.Any]: return metadata.to_dict() - def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: + def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: with SafeLoad(self._model, shard_names=self.get_shard_names(config), timeout=config.timeout) as context: # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from # `state_dict` that are ready for conversion, @@ -116,7 +116,7 @@ class FastLLMCheckpointHandler(StateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = FastLLMCheckpointFormat @classmethod - def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: path = config.path / f"metadata.yaml" logger.warning(f"Loading metadata from {path}") return CheckpointMetadata.from_dict(yaml.safe_load(path.open("r"))) diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index 080708049..d4b46bcc0 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -74,7 +74,10 @@ def _get_config_dict( torch_dtype = kwargs.pop("torch_dtype", None) if torch_dtype is not None: updates[("distributed", "training_dtype")] = torch_dtype - fast_llm_config = cls.model_config_class.from_dict(metadata.config, kwargs.pop("fast_llm_config", {}), updates) + fast_llm_config = cls.model_config_class.from_metadata( + pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates + ) + config_dict = {"fast_llm_config": fast_llm_config} return config_dict, kwargs diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index e2255faaa..de26f9bf6 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -31,16 +31,15 @@ def save_checkpoint( ) converter.save(config, fast_llm_metadata) - def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any]: + def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: # TODO: Simplify branching. # TODO: Test with more distributed configs. # TODO: Safety checks # TODO: Handle barriers, ok file, etc. here - metadata = self.config_class.load_metadata(config) converter = config.format.get_handler_class()(self) - converter.load(config, metadata) + metadata = converter.load(config) self._finalize_load(reset_optimizer=not config.optimizer_state) - return metadata.metadata + return metadata @classmethod def from_pretrained( diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index f2ed4a38f..c6daa0813 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -494,6 +494,7 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) -> metadata = self._multi_stage.load_checkpoint( config.get_load_config(checkpoint_directory, timeout=self._config.training.timeout) ) + assert metadata is not None self._optimizer.load(metadata["optimizer"]) if "schedules" in metadata: # Backward compatibility. diff --git a/tests/common.py b/tests/common.py index cc7499019..9ecb60ff9 100644 --- a/tests/common.py +++ b/tests/common.py @@ -54,7 +54,7 @@ "model.base_model.transformer.num_layers=2", "model.base_model.transformer.hidden_size=256", "model.base_model.transformer.num_attention_heads=8", - # "model.base_model.transformer.init_method_std=0.022", + "model.base_model.transformer.init_method_std=0.022", f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", @@ -101,7 +101,7 @@ "--global-batch-size=8", "--max-position-embeddings=512", "--seq-length=512", - "--init-method-std=0.0625", + "--init-method-std=0.022", "--lr=0.0001", "--num-workers=0", "--valid-num-workers=0", @@ -394,7 +394,7 @@ def run_test_script( if num_gpus == 1 and not is_megatron: CliTrainingConfig.parse_and_run(script) else: - completed_proc = subprocess.run(command, env=env) + completed_proc = subprocess.run(command, env=env, timeout=30) if completed_proc.returncode: raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") if compare: diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 6793a6700..4171581a0 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -14,7 +14,7 @@ FastLLMCheckpointFormat, ModelConfigType, ) -from fast_llm.engine.multi_stage.config import StageMode +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import ShardName from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConversionConfig @@ -246,8 +246,12 @@ def test_converted_huggingface(): assert (h0[key] == h1[key]).all() -def _compare_configs(config_ref, config_test): - config_ref.compare(config_test) +def _compare_model_configs(config_ref: FastLLMModelConfig, config_test: FastLLMModelConfig): + config_ref.base_model.compare(config_test.base_model) + + +def _compare_architectures(config_ref: FastLLMModelConfig, config_test: FastLLMModelConfig): + config_ref.base_model.get_architecture().compare(config_test.base_model.get_architecture()) @pytest.mark.depends(on=["test_converted_distributed"]) @@ -261,7 +265,7 @@ def test_load_pretrained_distributed_checkpoint(): optimizer_state=True, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) - _compare_configs(config.base_model, model.config.base_model) + _compare_model_configs(config, model.config) state_shards = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) ) @@ -271,20 +275,24 @@ def test_load_pretrained_distributed_checkpoint(): @pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) def test_load_converted_distributed_checkpoint(): - pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) - pretrained_config_0 = CheckpointLoadConfig( - path=_CONVERT_PATH / "distributed_0", - format=DistributedCheckpointFormat, + config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) ) - pretrained_config_1 = CheckpointLoadConfig( - path=_CONVERT_PATH / "distributed_1", - format=DistributedCheckpointFormat, + + model = TEST_MODEL_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CONVERT_PATH / "distributed_0", + format=DistributedCheckpointFormat, + ) ) - config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) - model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0) - config_1 = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_1) - _compare_configs(config.base_model, model.config.base_model) - _compare_configs(config.base_model, config_1.base_model) + config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CONVERT_PATH / "distributed_1", + format=DistributedCheckpointFormat, + ) + ) + _compare_architectures(config_ref, model.config) + _compare_model_configs(model.config, config_alt) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) )[WEIGHT_SHARD_SAVE_NAME] @@ -293,14 +301,17 @@ def test_load_converted_distributed_checkpoint(): @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_fast_llm_checkpoint(): - pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) - pretrained_config_0 = CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat) - pretrained_config_1 = CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat) - config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) - model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0) - config_1 = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_1) - _compare_configs(config.base_model, model.config.base_model) - _compare_configs(config.base_model, config_1.base_model) + config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) + ) + model = TEST_MODEL_CLS.from_pretrained( + CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat) + ) + config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat) + ) + _compare_architectures(config_ref, model.config) + _compare_architectures(config_ref, config_alt) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) )[WEIGHT_SHARD_SAVE_NAME] @@ -309,23 +320,27 @@ def test_load_converted_fast_llm_checkpoint(): @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_huggingface_checkpoint(): - pretrained_config_ref = CheckpointLoadConfig( - path=_CKPT_PATH, - format=DistributedCheckpointFormat, + config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CKPT_PATH, + format=DistributedCheckpointFormat, + ) ) - pretrained_config_0 = CheckpointLoadConfig( - path=_CONVERT_PATH / "huggingface_0", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + model = TEST_MODEL_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CONVERT_PATH / "huggingface_1", + format=HUGGINGFACE_CHECKPOINT_FORMAT, + ), + mode=StageMode.weights, ) - pretrained_config_1 = CheckpointLoadConfig( - path=_CONVERT_PATH / "huggingface_1", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CONVERT_PATH / "huggingface_0", + format=HUGGINGFACE_CHECKPOINT_FORMAT, + ) ) - config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) - model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0, mode=StageMode.weights) - config_1 = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_1) - _compare_configs(config.base_model, model.config.base_model) - _compare_configs(config.base_model, config_1.base_model) + _compare_architectures(config_ref, model.config) + _compare_model_configs(model.config, config_alt) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) )[WEIGHT_SHARD_SAVE_NAME] @@ -423,7 +438,7 @@ def test_load_pretrained_in_dp2_match_checkpoint(): ) config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) config_test = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_test) - _compare_configs(config_ref.base_model, config_test.base_model) + _compare_model_configs(config_ref, config_test) shards_ref = safetensors.torch.load_file(_CKPT_PATH / "rank_0.safetensors") shards_test = [safetensors.torch.load_file(test_ckpt_path / f"rank_{i}.safetensors") for i in range(2)] ref_model = TEST_MODEL_CLS(config_ref) @@ -467,7 +482,7 @@ def test_load_distributed_checkpoint_dp2(): ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_test, mode=StageMode.weights) - _compare_configs(config.base_model, model.config.base_model) + _compare_model_configs(config, model.config) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) )[WEIGHT_SHARD_SAVE_NAME] diff --git a/tests/test_config.py b/tests/test_config.py index 79437e9d6..ed758965e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -10,6 +10,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.transformer.config import AddLinearBiasChoices, TransformerArchitectureConfig, TransformerConfig from fast_llm.models.auto import trainer_registry +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.utils import check_equal_nested def run_without_import(cmd: str): @@ -114,8 +116,11 @@ def test_add_attn_dense_bias(): ) -@pytest.mark.parametrize("cls", (GPTSamplingConfig,)) -def test_serialize_default_config_updates(cls): +@pytest.mark.parametrize( + ("cls", "default"), + ((GPTSamplingConfig, {}), (GPTModelConfig, {"distributed": {"world_size": 1, "rank": 0, "local_world_size": 1}})), +) +def test_serialize_default_config_updates(cls, default): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. - assert cls.from_dict({}).to_dict() == {} + check_equal_nested(cls.from_dict({}).to_dict(), default) From baad705d6960d9578a2f5e29664284250d569980 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 19:16:01 -0400 Subject: [PATCH 016/122] fix --- fast_llm/layers/transformer/config.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index a1cb658e0..cf409e773 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -186,7 +186,7 @@ class TransformerSubLayerName(str, enum.Enum): @config_class() class TransformerPeftConfig(PeftConfig): layers: list[TransformerSubLayerName] = Field( - default_factory=lambda: [TransformerSubLayerName.query, TransformerSubLayerName.value_], + default=None, desc="The layers on which to apply LoRA.", hint=FieldHint.feature, ) @@ -220,6 +220,15 @@ def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": return parameter def _validate(self) -> None: + if self.layers is None: + with self._set_implicit_default(): + # Setting the default layers only whee PeFT is enabled + # so they don't appear when serializing the default transformer config. + self.layers = ( + [TransformerSubLayerName.query, TransformerSubLayerName.value_] + if self.type == PeftType.lora + else [] + ) if self.type != PeftType.none: if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: # TODO: Add MLP support. From b7028378a2f8cb4e6e863ac55af69b0f11f71cff Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 4 Apr 2025 21:48:10 -0400 Subject: [PATCH 017/122] Test, fixes --- fast_llm/engine/checkpoint/config.py | 11 +-- fast_llm/engine/checkpoint/distributed.py | 7 ++ fast_llm/engine/checkpoint/huggingface.py | 6 +- fast_llm/engine/checkpoint/state_dict.py | 17 ++++- fast_llm/engine/multi_stage/config.py | 14 ++-- tests/test_config.py | 84 ++++++++++++++++++++++- 6 files changed, 123 insertions(+), 16 deletions(-) diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 55440a5cc..62928ed07 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -201,9 +201,9 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf @config_class() class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): _abstract = False - + # TODO: Set default to model? (Not backward compatible) load_config: ModelConfigType = Field( - default=ModelConfigType.model, + default=ModelConfigType.architecture, desc="Configuration to save/load.", hint=FieldHint.core, ) @@ -233,7 +233,10 @@ class CheckpointHandler(abc.ABC): def __init__(self, model: "FastLLMModel"): self._model = model - # TODO: save_metadata? + @classmethod + @abc.abstractmethod + def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: "CheckpointMetadata"): + pass @classmethod def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": @@ -245,7 +248,7 @@ def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetad if not config.load_config.load_architecture: updates[("config", "base_model")] = {} elif not config.load_config.load_base_model: - updates[("config", "base_model")] = metadata.config.base_model.get_architecture() + updates[("config", "base_model")] = metadata.config.base_model.get_architecture().to_dict() if updates: metadata = metadata.to_copy(updates) return metadata diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index ac06df5c4..de1625f6b 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -12,6 +12,7 @@ CheckpointLoadConfig, CheckpointLoadMetadataConfig, CheckpointSaveConfig, + CheckpointSaveMetadataConfig, DistributedCheckpointFormat, ModelConfigType, export_safetensors_metadata, @@ -27,6 +28,12 @@ class DistributedCheckpointHandler(CheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = DistributedCheckpointFormat + @classmethod + def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata): + config.path.mkdir(parents=True, exist_ok=True) + serialized_metadata = metadata.to_dict() + yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) + @classmethod def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r"))) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index a5777d45f..2972a4fa3 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -20,8 +20,10 @@ class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC): - def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: - path = config.path / f"{self.base_file_name}.safetensors.index.json" + @classmethod + def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: + config.path.mkdir(parents=True, exist_ok=True) + path = config.path / f"{cls.base_file_name}.safetensors.index.json" logger.info(f"Saving index to {path}") # Save the index. json.dump( diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 1bb47e5c3..556e97be9 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -30,6 +30,13 @@ class StateDictCheckpointHandler(CheckpointHandler): base_file_name: typing.ClassVar[str] = "model" + @classmethod + def save_metadata( + cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata, index: dict | None = None + ): + serialized_metadata = cls._serialize_metadata(config, metadata) + cls._save_serialized_metadata(config, serialized_metadata, {} if index is None else index) + def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: serialized_metadata = self._serialize_metadata(config, metadata) saver = StateDictSaver( @@ -64,12 +71,14 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No if self._model.config.distributed.rank == 0: self._save_serialized_metadata(config, serialized_metadata, index) + @classmethod @abc.abstractmethod - def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: + def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: pass + @classmethod def _serialize_metadata( - self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata + cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata ) -> dict[str, typing.Any]: return metadata.to_dict() @@ -121,9 +130,11 @@ def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetad logger.warning(f"Loading metadata from {path}") return CheckpointMetadata.from_dict(yaml.safe_load(path.open("r"))) + @classmethod def _save_serialized_metadata( - self, config: CheckpointSaveMetadataConfig, serialized_metadata: dict, index: dict + cls, config: CheckpointSaveMetadataConfig, serialized_metadata: dict, index: dict ) -> None: + config.path.mkdir(parents=True, exist_ok=True) path = config.path / f"metadata.yaml" logger.info(f"Saving metadata to {path}") if "metadata" not in serialized_metadata: diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 43b412fbd..6a0c88137 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -187,11 +187,12 @@ class MultiStageConfig(StageConfig): def _validate(self) -> None: super()._validate() if self.zero_stage is not None: - Assert.in_range_incl(self.zero_stage, 1, 3) - if self.zero_stage >= 2: - self.num_grad_buffers = 2 - if self.zero_stage >= 3: - self.num_weight_buffers = 2 + with self._set_implicit_default(): + Assert.in_range_incl(self.zero_stage, 1, 3) + if self.zero_stage >= 2: + self.num_grad_buffers = 2 + if self.zero_stage >= 3: + self.num_weight_buffers = 2 if self.num_grad_buffers is not None: Assert.geq(self.num_grad_buffers, 1) if self.num_weight_buffers is not None: @@ -281,6 +282,9 @@ def to_metadata(self, config: CheckpointSaveMetadataConfig, **kwargs) -> "Checkp **kwargs, ) + def save_metadata(self, config: CheckpointSaveMetadataConfig, **kwargs) -> None: + self.get_checkpoint_handler_class(config.format).save_metadata(config, self.to_metadata(config, **kwargs)) + @config_class() class PretrainedFastLLMModelConfig(Config): diff --git a/tests/test_config.py b/tests/test_config.py index ed758965e..79c6738d7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,13 +5,16 @@ import pytest import yaml +from fast_llm.config import NoAutoValidate from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.engine.checkpoint.config import CheckpointSaveMetadataConfig, ModelConfigType from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.transformer.config import AddLinearBiasChoices, TransformerArchitectureConfig, TransformerConfig from fast_llm.models.auto import trainer_registry -from fast_llm.models.gpt.config import GPTModelConfig -from fast_llm.utils import check_equal_nested +from fast_llm.models.gpt.config import GPTModelConfig, PretrainedGPTModelConfig +from fast_llm.utils import Assert, check_equal_nested +from tests.common import TEST_RESULTS_PATH def run_without_import(cmd: str): @@ -124,3 +127,80 @@ def test_serialize_default_config_updates(cls, default): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. check_equal_nested(cls.from_dict({}).to_dict(), default) + + +@pytest.mark.parametrize("load_config", tuple(ModelConfigType)) +def test_pretrained_config(load_config: ModelConfigType): + config_path = TEST_RESULTS_PATH / "pretrained_config" + pretrained_model_config = GPTModelConfig.from_dict( + { + "base_model": { + "transformer": { + "normalization": {"type": "rms_norm"}, # Nested + "rotary": {"type": "default"}, + "num_layers": 12, # Default + "hidden_size": 1024, # Default + "window_size": 32, # Non-architecture + "ffn_hidden_size": 4096, # Implicit default, default value + "activation_type": "silu", # Implicit default, non-default value + "head_groups": 4, + }, + "tie_word_embeddings": False, + }, + "multi_stage": {"zero_stage": 3}, + "distributed": {"training_dtype": "bfloat16"}, + } + ) + with NoAutoValidate(): + save_config = CheckpointSaveMetadataConfig.from_dict({"format": "fast_llm", "path": config_path}) + save_config.setup(GPTModelConfig) + save_config.validate() + pretrained_model_config.save_metadata(save_config) + + base_model_update = { + "transformer": { + # rotary: Don't override nested. + "normalization": {"implementation": "triton"}, # Update non-default nested + "peft": {"freeze_others": False}, # Update default nested, non-architecture + "hidden_size": 512, # Override, affects derived value (kv channels) + "head_groups": 1, # Override to default + }, + "vocab_size": 1000, + } + pretrained_config = PretrainedGPTModelConfig.from_dict( + { + "model": { + "base_model": base_model_update, + "distributed": {"seed": 1234, "training_dtype": "float16"}, + }, + "pretrained": {"format": "fast_llm", "path": config_path, "load_config": load_config}, + } + ) + Assert.eq(pretrained_config.model.base_model.transformer.kv_channels, 64) + serialized_config = pretrained_config.model.to_dict() + expected_config = {"distributed": DistributedConfig().to_dict()} + + if load_config == ModelConfigType.fast_llm: + expected_config["multi_stage"] = {"zero_stage": 3} + expected_config["distributed"].update({"seed": 1234, "training_dtype": "float16"}) + if load_config in (ModelConfigType.architecture, ModelConfigType.fast_llm, ModelConfigType.model): + expected_config["base_model"] = { + "transformer": { + "normalization": {"type": "rms_norm", "implementation": "triton"}, + "rotary": {"type": "default"}, + "peft": {"freeze_others": False}, + "num_layers": 12, + "hidden_size": 512, + "ffn_hidden_size": 4096, + "activation_type": "silu", + "head_groups": 1, + }, + "tie_word_embeddings": False, + "vocab_size": 1000, + } + if load_config != ModelConfigType.architecture: + expected_config["base_model"]["transformer"]["window_size"] = 32 + else: + expected_config["base_model"] = base_model_update + + check_equal_nested(serialized_config, expected_config) From a8684f869a3377f13fbf96c87a7fb850aed52757 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 11 Apr 2025 00:28:08 -0400 Subject: [PATCH 018/122] Knowledge distillation, fix cross-entropy --- fast_llm/functional/config.py | 6 + fast_llm/functional/cross_entropy.py | 191 +++++++++++--------- fast_llm/functional/triton/cross_entropy.py | 123 ++++++++++--- tests/test_triton_kernels.py | 73 ++++++-- 4 files changed, 263 insertions(+), 130 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 9f1fe005e..7284ca071 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -91,3 +91,9 @@ class CrossEntropyImpl(str, enum.Enum): torch = "torch" fused = "fused" triton = "triton" + + +class TargetFormat(enum.StrEnum): + labels = "labels" + logits = "logits" + probabilities = "probabilities" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index e87581f1b..62c61e8e9 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -3,7 +3,7 @@ import torch.autograd from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce -from fast_llm.functional.config import CrossEntropyImpl +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.utils import Assert @@ -12,34 +12,65 @@ def torch_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, grad_output: float | None, - logits_scale_factor: float = 1.0, + logits_scale_factor: float, + target_format: TargetFormat, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A wrapper for the pytorch implementation of cross-entropy. The cross-entropy kernels themselves are well-optimized, but the need for explicit casting and separate forward and backward kernels lead to poor performance. - TODO: loss masking only works for this method if the masking index is set to -100. + TODO: loss masking only works for with labels format and if the masking index is set to -100. """ # Torch compile doesn't understand this. - with torch.enable_grad(): - logits_ = logits.float().detach().requires_grad_() - if logits_scale_factor != 1.0: - logits_ *= logits_scale_factor + with torch.set_grad_enabled(grad_output is not None): + logits_ = logits.float().detach().requires_grad_(grad_output is not None) + if target_format == TargetFormat.logits: + target = torch.softmax(target, dim=-1) + loss = torch.nn.functional.cross_entropy( + logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target + ).mean() if grad_output is None: - loss = None + grad = None else: - loss = torch.nn.functional.cross_entropy(logits_, target).mean() loss.backward(torch.full_like(loss, grad_output)) - loss.detach_() - return loss.detach(), logits_.grad.detach().to(logits.dtype) + grad = logits_.grad.detach().to(logits.dtype) + return loss.detach_(), grad + + +# @torch.compile +def _fused_softmax_base( + logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + logits = logits.float() + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + logits_max = torch.max(logits, dim=dim, keepdim=True)[0] + if group is not None: + all_reduce(logits_max, op=ReduceOp.MAX, group=group) + logits_norm = (logits - logits_max).float() + exp_logits = logits_norm.exp() + sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True) + if group is not None: + all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group) + return logits_norm, exp_logits, sum_exp_logits + + +# @torch.compile +def fused_softmax( + logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup = None, dim: int = -1 +) -> torch.Tensor: + _, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group, dim) + return exp_logits / sum_exp_logits -@torch.compile +# @torch.compile def fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, grad_output: float | None, - logits_scale_factor: float = 1.0, + logits_scale_factor: float, + target_format: TargetFormat, + group: ProcessGroup | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -48,82 +79,67 @@ def fused_cross_entropy_forward_backward( """ # Do the forward and backward passes all at once, and fused with dtype conversion. # Way faster and more memory-efficient than the pytorch version. - loss_mask = target >= 0 - # Ignore_index can go out of bounds, so set masked values to zero. - target = (target * loss_mask).unsqueeze(1) - logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float() - if logits_scale_factor != 1.0: - logits_norm *= logits_scale_factor - exp_logits = logits_norm.exp() - sum_exp_logits = exp_logits.sum(dim=-1) - - if grad_output is None: - grad = None - else: - exp_logits = exp_logits.scatter(1, target, exp_logits.gather(1, target) - sum_exp_logits.unsqueeze(dim=-1)) - # exp_logits[torch.arange(0, logits.size(0), device=logits.device), target.squeeze(dim=-1)]-=sum_exp_logits - exp_logits = exp_logits.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1)) - if logits_scale_factor != 1.0: - exp_logits *= logits_scale_factor - - grad = torch.where(loss_mask.unsqueeze(1), exp_logits.to(logits.dtype), 0) - - per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) * loss_mask + logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) - return per_sample_loss.mean(), grad + if target_format == TargetFormat.logits: + target = fused_softmax(target, logits_scale_factor, group) - -@torch.compile -def parallel_cross_entropy_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - grad_output: float | None, - group: ProcessGroup, - logits_scale_factor: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - A fused implementation of cross-entropy with torch compile, with support for tensor parallelism. - Comes with a noticeable overhead, but reduces memory usage. - """ - # TODO: Compiled version incorrect for some inputs (32 bit indexing issue?). - # TODO: Optimize, overlap/combine reductions - loss_mask = target >= 0 - target = target.unsqueeze(1) - - logits_max = torch.max(logits, dim=-1)[0] - all_reduce(logits_max, op=ReduceOp.MAX, group=group) - logits_norm = logits.sub(logits_max.unsqueeze(dim=-1)).float() - if logits_scale_factor != 1.0: - logits_norm *= logits_scale_factor - - exp_logits = logits_norm.exp() - sum_exp_logits = exp_logits.sum(dim=-1) - all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group) - - # Mask the target (fused) - # TODO: Could mask earlier on cpu or overlap with reduce? - vocab_start_index = logits.size(-1) * group.rank() - target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) - target = (target - vocab_start_index) * target_mask + if target_format == TargetFormat.labels: + target = target.unsqueeze(-1) + loss_mask = target >= 0 + if group is None: + # Keep values within range for scatter and gather ops to work. + target = target * loss_mask + target_mask = None + else: + # Mask the target (fused) + # TODO: Could mask earlier on cpu or overlap with reduce? + vocab_start_index = logits.size(-1) * group.rank() + target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) + target = (target - vocab_start_index) * target_mask + else: + # TODO: Support masking + loss_mask = None + # Target should be tensor-parallel already, no further manipulation needed. + target_mask = None if grad_output is None: grad = None else: - exp_logits1 = exp_logits.scatter( - 1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1) - ) - exp_logits2 = exp_logits1.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1)) + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. + if target_format == TargetFormat.labels: + grad_base = exp_logits.scatter_add( + 1, target, -sum_exp_logits if target_mask is None else -target_mask * sum_exp_logits + ) + else: + grad_base = exp_logits - sum_exp_logits * target + + grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits) if logits_scale_factor != 1.0: - exp_logits2 *= logits_scale_factor + grad *= logits_scale_factor + grad = grad.to(logits.dtype) + if loss_mask is not None: + grad = torch.where(loss_mask, grad.to(logits.dtype), 0) + + # loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) + if target_format == TargetFormat.labels: + predicted_logits = logits_norm.gather(1, target) + if group is not None: + predicted_logits = target_mask * predicted_logits + all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) + else: + predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) - grad = torch.where(loss_mask.unsqueeze(1), exp_logits2.to(logits.dtype), 0) + per_sample_loss = sum_exp_logits.log() - predicted_logits + if loss_mask is not None: + per_sample_loss = per_sample_loss * loss_mask - predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1) - all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) - per_sample_loss = sum_exp_logits.log().sub(predicted_logits) * loss_mask + loss = per_sample_loss.mean() + if target_format != TargetFormat.labels and group is not None: + all_reduce(loss, op=ReduceOp.MEAN, group=group) - return per_sample_loss.mean(), grad + return loss, grad _CROSS_ENTROPY_IMPLEMENTATIONS = { @@ -134,12 +150,13 @@ def parallel_cross_entropy_forward_backward( def cross_entropy_forward_backward( - logits, - target, + logits: torch.Tensor, + target: torch.Tensor, grad_output: float | None, - group: ProcessGroup | None, + group: ProcessGroup | None = None, implementation: CrossEntropyImpl = CrossEntropyImpl.fused, logits_scale_factor: float = 1.0, + target_format: TargetFormat = TargetFormat.labels, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Select the appropriate implementation of cross-entropy. @@ -147,12 +164,18 @@ def cross_entropy_forward_backward( It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, which is faster and has a relatively small memory overhead. """ + if target_format == TargetFormat.labels: + Assert.eq(target.shape, logits.shape[:-1]) + Assert.eq(target.dtype, torch.int64) + else: + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype if group: Assert.eq(implementation, CrossEntropyImpl.fused) - return parallel_cross_entropy_forward_backward( - logits, target, grad_output, group, logits_scale_factor=logits_scale_factor + return fused_cross_entropy_forward_backward( + logits, target, grad_output, logits_scale_factor, target_format, group ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, grad_output, logits_scale_factor=logits_scale_factor + logits, target, grad_output, logits_scale_factor, target_format ) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 8b6228498..321bd0fa2 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -1,6 +1,6 @@ import torch -from fast_llm.functional.config import TritonConfig +from fast_llm.functional.config import TargetFormat, TritonConfig from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit @@ -11,9 +11,9 @@ def triton_cross_entropy_forward_backward_kernel( grad_logits_ptr, losses_ptr, grad_losses, - n_cols, - logits_stride_0, - grad_logits_stride_0, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + grad_logits_stride_0: tl_constexpr, logits_scale_factor: tl_constexpr, block_size: tl_constexpr, ): @@ -33,27 +33,78 @@ def triton_cross_entropy_forward_backward_kernel( label_idx = tl.load(labels_ptr + block_idx) - label_logits = tl.load(logits_ptr + label_idx).to(tl.float32) if label_idx < 0: + # Loss mask loss = 0.0 else: + label_logits = tl.load(logits_ptr + label_idx).to(tl.float32) + if logits_scale_factor != 1.0: + label_logits *= logits_scale_factor loss = tl.log(sum_exp_logits) + max_logits - label_logits tl.store(losses_ptr + block_idx, loss) - grad_logits_ptr = grad_logits_ptr + block_idx * grad_logits_stride_0 + if grad_losses is not None: + if label_idx < 0: + grad_losses = 0.0 + grad_base = exp_logits / sum_exp_logits + grad_logits = grad_losses * tl.where(col_offsets == label_idx, grad_base - 1.0, grad_base) + if logits_scale_factor != 1.0: + grad_logits *= logits_scale_factor + tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) + + +@triton_jit() +def triton_cross_entropy_from_distribution_forward_backward_kernel( + logits_ptr, + target_ptr, + grad_logits_ptr, + losses_ptr, + grad_losses, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + grad_logits_stride_0: tl_constexpr, + logits_scale_factor: tl_constexpr, + from_logits: tl_constexpr, + block_size: tl_constexpr, +): + # TODO: Int64 ptr only if needed? + block_idx = tl.program_id(0).to(tl.int64) col_offsets = tl.arange(0, block_size) - label_idx = tl.load(labels_ptr + block_idx) - exp_logits = exp_logits / sum_exp_logits + logits_ptr = logits_ptr + block_idx * logits_stride_0 + mask = col_offsets < n_cols + + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + + max_logits = tl.max(logits, 0) + exp_logits = tl.exp(logits - max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if from_logits: + max_target_logits = tl.max(logits, 0) + exp_target_logits = tl.exp(target - max_target_logits) + sum_exp_target_logits = tl.sum(exp_target_logits, 0) + target = exp_target_logits / sum_exp_target_logits + + # per_sample_loss = log(sum_exp_logits) - sum(probabilities * logits) + loss = tl.log(sum_exp_logits) - tl.sum(target * logits, 0) + tl.store(losses_ptr + block_idx, loss) + + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. if logits_scale_factor != 1.0: exp_logits *= logits_scale_factor - if label_idx < 0: - grad_losses = 0.0 - grad_logits = grad_losses * tl.where(col_offsets == label_idx, exp_logits - 1.0, exp_logits) - tl.store(grad_logits_ptr + col_offsets, grad_logits, mask=mask) + grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) + tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) def triton_cross_entropy_forward_backward( - logits, target, grad_output: float | None, logits_scale_factor: float = 1.0 + logits: torch.Tensor, + target: torch.Tensor, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -72,18 +123,34 @@ def triton_cross_entropy_forward_backward( num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) # TODO: Safe to do inplace? - grad_logits = torch.empty_like(logits) - triton_cross_entropy_forward_backward_kernel[(n_rows,)]( - logits, - target, - grad_logits, - losses, - 1 if grad_output is None else grad_output / n_rows, - n_cols, - logits.stride(0), - grad_logits.stride(0), - logits_scale_factor, - block_size=block_size, - num_warps=num_warps, - ) - return losses.mean(), None if grad_output is None else grad_logits + grad_logits = None if grad_output is None else torch.empty_like(logits) + if target_format == TargetFormat.labels: + triton_cross_entropy_forward_backward_kernel[(n_rows,)]( + logits, + target, + grad_logits, + losses, + None if grad_output is None else grad_output / n_rows, + n_cols, + logits.stride(0), + None if grad_output is None else grad_logits.stride(0), + logits_scale_factor, + block_size=block_size, + num_warps=num_warps, + ) + else: + triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( + logits, + target, + grad_logits, + losses, + None if grad_output is None else grad_output / n_rows, + n_cols, + logits.stride(0), + None if grad_output is None else grad_logits.stride(0), + logits_scale_factor, + block_size=block_size, + num_warps=num_warps, + from_logits=target_format == TargetFormat.logits, + ) + return losses.mean(), grad_logits diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py index e61c2d51c..e52073aa7 100644 --- a/tests/test_triton_kernels.py +++ b/tests/test_triton_kernels.py @@ -1,14 +1,20 @@ import pytest import torch -from fast_llm.functional.config import MAX_DROPLESS_BLOCK_SIZE_ROW, ActivationType, TritonConfig +from fast_llm.functional.config import ( + MAX_DROPLESS_BLOCK_SIZE_ROW, + ActivationType, + CrossEntropyImpl, + TargetFormat, + TritonConfig, +) +from fast_llm.functional.cross_entropy import cross_entropy_forward_backward from fast_llm.functional.rotary import ( apply_rotary_embeddings, convert_rotary_complex_to_real, convert_rotary_real_to_complex, ) from fast_llm.functional.triton.adam import triton_adam -from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.functional.triton.mlp import ( torch_mlp_activation, triton_mlp_activation_backward, @@ -184,24 +190,55 @@ def test_triton_mlp_activation(gated, activation_type, recompute): @requires_cuda -def test_triton_cross_entropy(): +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor"), + ( + (8192, 1.0, 1.0), + (8192, None, 1.0), + (8192, 1.0, 4.0), + (8192, 4.0, 1.0), + (65536, 1.0, 1.0), + (131072, 1.0, 1.0), + ), +) +@pytest.mark.parametrize("target_format", (TargetFormat.labels,)) # TargetFormat.logits, TargetFormat.probabilities)) +def test_cross_entropy(num_columns, grad_output, logits_scale_factor, target_format): + # TODO: Test tensor-parallel implementation. assert TritonConfig.TRITON_ENABLED - logits = torch.randn(1024, 8192, dtype=torch.bfloat16, device="cuda", requires_grad=True) - labels = torch.randint(0, 8192, (1024,), dtype=torch.int64, device="cuda") - - from fast_llm.functional.cross_entropy import ( - fused_cross_entropy_forward_backward, - torch_cross_entropy_forward_backward, - ) - - c1, g1 = torch_cross_entropy_forward_backward(logits, labels, 1) - c2, g2 = fused_cross_entropy_forward_backward(logits, labels, 1) - c3, g3 = triton_cross_entropy_forward_backward(logits, labels, 1) + logits = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda", requires_grad=True) + if target_format == TargetFormat.labels: + target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device="cuda") + else: + target = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") + + kwargs = { + "logits": logits, + "target": target, + "grad_output": grad_output, + "logits_scale_factor": logits_scale_factor, + "target_format": target_format, + } + # Torch serves as the reference implementation. + out_torch, grad_torch = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.torch) + + out_fused, grad_fused = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.fused) + Assert.rms_close(out_fused, out_torch, 5e-3) + if grad_output is None: + assert grad_torch is None + assert grad_fused is None + else: + Assert.rms_close(grad_fused, grad_torch, 5e-3) - Assert.rms_close(c2, c3, 5e-3) - Assert.rms_close(c1, c3, 5e-3) - Assert.rms_close(g1, g3, 5e-3) - Assert.rms_close(g2, g3, 5e-3) + if target_format == TargetFormat.probabilities or num_columns > 65536: + with pytest.raises(AssertionError): + cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) + else: + out_triton, grad_triton = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) + Assert.rms_close(out_triton, out_torch, 5e-3) + if grad_output is None: + assert grad_triton is None + else: + Assert.rms_close(grad_triton, grad_torch, 5e-3) @requires_cuda From b781729d684b4c2415585277f333afc75999874d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Sun, 13 Apr 2025 11:15:17 -0400 Subject: [PATCH 019/122] Fixes, distillation --- fast_llm/functional/cross_entropy.py | 4 +- fast_llm/functional/triton/cross_entropy.py | 32 +++++++++------ fast_llm/layers/language_model/config.py | 11 +++++- fast_llm/layers/language_model/head.py | 43 +++++++++++++-------- fast_llm/models/gpt/config.py | 7 +++- fast_llm/models/gpt/model.py | 2 +- fast_llm/utils.py | 2 +- tests/test_triton_kernels.py | 14 +++++-- 8 files changed, 78 insertions(+), 37 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 62c61e8e9..0a6118328 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -25,6 +25,8 @@ def torch_cross_entropy_forward_backward( with torch.set_grad_enabled(grad_output is not None): logits_ = logits.float().detach().requires_grad_(grad_output is not None) if target_format == TargetFormat.logits: + if logits_scale_factor != 1.0: + target = target * logits_scale_factor target = torch.softmax(target, dim=-1) loss = torch.nn.functional.cross_entropy( logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target @@ -63,7 +65,7 @@ def fused_softmax( return exp_logits / sum_exp_logits -# @torch.compile +@torch.compile def fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 321bd0fa2..62ed2e0ee 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -62,6 +62,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( grad_losses, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, + target_stride_0: tl_constexpr, grad_logits_stride_0: tl_constexpr, logits_scale_factor: tl_constexpr, from_logits: tl_constexpr, @@ -70,33 +71,40 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) col_offsets = tl.arange(0, block_size) - logits_ptr = logits_ptr + block_idx * logits_stride_0 mask = col_offsets < n_cols - logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + logits = tl.load(logits_ptr + block_idx * logits_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( + tl.float32 + ) if logits_scale_factor != 1.0: logits *= logits_scale_factor max_logits = tl.max(logits, 0) - exp_logits = tl.exp(logits - max_logits) + logits_norm = logits - max_logits + exp_logits = tl.exp(logits_norm) sum_exp_logits = tl.sum(exp_logits, 0) - target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + target = tl.load(target_ptr + block_idx * target_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( + tl.float32 + ) if from_logits: - max_target_logits = tl.max(logits, 0) + if logits_scale_factor != 1.0: + target *= logits_scale_factor + max_target_logits = tl.max(target, 0) exp_target_logits = tl.exp(target - max_target_logits) sum_exp_target_logits = tl.sum(exp_target_logits, 0) target = exp_target_logits / sum_exp_target_logits # per_sample_loss = log(sum_exp_logits) - sum(probabilities * logits) - loss = tl.log(sum_exp_logits) - tl.sum(target * logits, 0) + loss = tl.log(sum_exp_logits) - tl.sum(target * logits_norm, 0) tl.store(losses_ptr + block_idx, loss) - # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. - if logits_scale_factor != 1.0: - exp_logits *= logits_scale_factor - grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) - tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) + if grad_losses is not None: + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. + grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) + if logits_scale_factor != 1.0: + grad_logits *= logits_scale_factor + tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) def triton_cross_entropy_forward_backward( @@ -117,7 +125,6 @@ def triton_cross_entropy_forward_backward( assert logits.is_contiguous() assert target.is_contiguous() n_rows, n_cols = logits.shape - assert target.shape == (n_rows,) block_size = triton.next_power_of_2(n_cols) assert block_size <= TritonConfig.MAX_BLOCK_SIZE_BYTES num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) @@ -147,6 +154,7 @@ def triton_cross_entropy_forward_backward( None if grad_output is None else grad_output / n_rows, n_cols, logits.stride(0), + target.stride(0), None if grad_output is None else grad_logits.stride(0), logits_scale_factor, block_size=block_size, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 3bd796033..22cce43a9 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -151,6 +151,12 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) + distillation_model: str | None = Field( + default=None, + desc="Name of the reference model to use for knowledge distillation." + "If provided, replace the loss with a distillation loss.", + hint=FieldHint.feature, + ) # Tensor-parallel word embeddings # (Default init std is different, dropout won't match, needs seq_first = False.) # (disable to allow for sequence-parallel embeddings and logits, better for larger models) @@ -195,6 +201,9 @@ def _validate(self) -> None: self.init_method_max_embed = self.transformer.init_method_max if self.init_method_min_embed is None: self.init_method_min_embed = self.transformer.init_method_min + super()._validate() if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, self.init_method_max_embed) - super()._validate() + if self.distillation_model is not None: + if self.prediction_heads > 1: + raise NotImplementedError("Multi-token prediction not supported with distillation.") diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 1286121c3..04e4020f3 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, TritonConfig +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss @@ -145,12 +145,22 @@ def forward( def _forward_backward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor | None]: - labels = kwargs[LanguageModelKwargs.labels] if LanguageModelKwargs.labels in kwargs else None - # MTP: Shift the labels - labels = labels[:, self._prediction_distance :].flatten() if labels is not None else None + target = kwargs.get( + LanguageModelKwargs.labels + if self._config.distillation_model is None + else f"{self._config.distillation_model}_logits" + ) + if target is not None: + if self._config.distillation_model is None: + # Target is labels (token ids) + # MTP: Shift the labels + target = target[:, self._prediction_distance :].flatten() + else: + # Target is reference model logits. + target = target.flatten(0, -2) if self._sequence_parallel_logits: - labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0) - do_grad = labels is not None and self.training + target = split_op(target, self._tensor_space.distributed.tensor_group, 0) + do_grad = target is not None and self.training input_ = input_.detach().requires_grad_(do_grad) with torch.enable_grad(): # MTP: truncate the input @@ -166,7 +176,7 @@ def _forward_backward( output_weights = self._get_output_weights(kwargs) loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( - ln_output.detach(), labels, output_weights, grad_output, kwargs, losses + ln_output.detach(), target, output_weights, grad_output, kwargs, losses ) if do_grad: @@ -185,29 +195,29 @@ def _get_output_weights(self, kwargs: dict) -> torch.Tensor: def _logits_cross_entropy_forward_backward_split( self, input_: torch.Tensor, - labels: torch.Tensor | None, + target: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._cross_entropy_splits is None or labels is None: + if self._cross_entropy_splits is None or target is None: loss, logit_input_grad = self._logits_cross_entropy_forward_backward( - input_, labels, weight, grad_output, kwargs, losses + input_, target, weight, grad_output, kwargs, losses ) - if labels is None: + if target is None: # TODO: Make a proper way of returning the model output. kwargs["logits"] = loss return None, None else: loss = None # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length - split_size = div(labels.numel(), self._cross_entropy_splits) + split_size = div(target.size(0), self._cross_entropy_splits) grad_output /= self._cross_entropy_splits logit_input = input_.flatten(0, -2) logit_input_grad = torch.empty_like(logit_input) for logit_input_, labels_, logit_input_grad_ in zip( - logit_input.split(split_size), labels.split(split_size), logit_input_grad.split(split_size) + logit_input.split(split_size), target.split(split_size), logit_input_grad.split(split_size) ): loss_, grad_ = self._logits_cross_entropy_forward_backward( logit_input_, @@ -231,7 +241,7 @@ def _logits_cross_entropy_forward_backward_split( def _logits_cross_entropy_forward_backward( self, input_: torch.Tensor, - labels: torch.Tensor | None, + target: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -285,15 +295,16 @@ def _logits_cross_entropy_forward_backward( scale=self._logits_scale_factor, ) - if labels is None: + if target is None: return logits * self._logits_scale_factor, None loss, grad = cross_entropy_forward_backward( logits.flatten(0, -2), - labels, + target, group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, grad_output=grad_output, implementation=self._cross_entropy_impl, logits_scale_factor=self._logits_scale_factor, + target_format=TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits, ) # TODO: de-allocate earlier. del logits diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 19c8e6ac6..09c3e757d 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -139,9 +139,14 @@ def _validate(self) -> None: self.batch.sequence_length = self.model.base_model.max_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) + super()._validate() + if (name := self.model.base_model.distillation_model) is None: + Assert.empty(self.reference_models) + else: + Assert.eq(self.reference_models.keys(), {name}) for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.cross_entropy_splits) - super()._validate() + Assert.none(reference_model.model.base_model.distillation_model) @classmethod def get_trainer_class(cls) -> type["GPTTrainer"]: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c672b2165..c0eabc453 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -269,7 +269,7 @@ def preprocess( labels = batch.token_ids[sequence_offset : sequence_k + 1] else: # TODO: Avoid multiple contiguous calls? - labels = batch.token_ids[:, sequence_k - sequence_q + 1 : sequence_k + 1].contiguous() + labels = batch.token_ids[:, sequence_offset : sequence_k + 1].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config if batch.loss_masking_spans is not None: diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 0a4ce007f..2499676ce 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -86,7 +86,7 @@ class Assert: @staticmethod def eq(x, *args, msg=None): for arg in args: - assert x == arg, f"{x} != {arg} " + f"| {msg}" if msg else "" + assert x == arg, f"{x} != {arg} " + (f"| {msg}" if msg else "") @staticmethod def is_(x, y): diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py index e52073aa7..b6970ddfe 100644 --- a/tests/test_triton_kernels.py +++ b/tests/test_triton_kernels.py @@ -190,6 +190,7 @@ def test_triton_mlp_activation(gated, activation_type, recompute): @requires_cuda +@pytest.mark.slow @pytest.mark.parametrize( ("num_columns", "grad_output", "logits_scale_factor"), ( @@ -201,15 +202,20 @@ def test_triton_mlp_activation(gated, activation_type, recompute): (131072, 1.0, 1.0), ), ) -@pytest.mark.parametrize("target_format", (TargetFormat.labels,)) # TargetFormat.logits, TargetFormat.probabilities)) +@pytest.mark.parametrize("target_format", (TargetFormat.labels, TargetFormat.logits, TargetFormat.probabilities)) def test_cross_entropy(num_columns, grad_output, logits_scale_factor, target_format): # TODO: Test tensor-parallel implementation. assert TritonConfig.TRITON_ENABLED - logits = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda", requires_grad=True) + # We want something moderately close to the target for the test to be meaningful + logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") / 3 if target_format == TargetFormat.labels: target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device="cuda") + logits = (torch.nn.functional.one_hot(target, num_columns) + logits_var).requires_grad_() else: target = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") + logits = (target + logits_var).requires_grad_() + if target_format == TargetFormat.probabilities: + target = torch.softmax(target, -1) kwargs = { "logits": logits, @@ -229,16 +235,16 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, target_for else: Assert.rms_close(grad_fused, grad_torch, 5e-3) - if target_format == TargetFormat.probabilities or num_columns > 65536: + if num_columns > 65536: with pytest.raises(AssertionError): cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) else: out_triton, grad_triton = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) - Assert.rms_close(out_triton, out_torch, 5e-3) if grad_output is None: assert grad_triton is None else: Assert.rms_close(grad_triton, grad_torch, 5e-3) + Assert.rms_close(out_triton, out_torch, 5e-3) @requires_cuda From db6504b0546b8462e79173e5f03e960c2d694d6d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 14 Apr 2025 14:55:18 -0400 Subject: [PATCH 020/122] fixes --- fast_llm/config.py | 1 + fast_llm/engine/multi_stage/config.py | 3 ++- fast_llm/engine/training/config.py | 8 +++----- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index f1c889658..443925ce2 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -30,6 +30,7 @@ def __enter__(self): global _AUTO_VALIDATE self._old_value = _AUTO_VALIDATE _AUTO_VALIDATE = False + return _AUTO_VALIDATE def __exit__(self, exc_type, exc_val, exc_tb): global _AUTO_VALIDATE diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index e6de074f4..ee94ce61a 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -336,7 +336,8 @@ def _validate(self) -> None: self.pretrained.setup(self.model) self.pretrained.validate() if self.pretrained.path is not None: - self.model = self.model.from_pretrained(self.pretrained, default=self.model) + with NoAutoValidate(): + self.model = self.model.from_pretrained(self.pretrained, default=self.model) self._setup() super()._validate() diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 4f5164b0b..9819ced35 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -383,17 +383,15 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): def _validate(self) -> None: self.training.export.setup(self.model) - self.model.validate() + for reference_model in self.reference_models.values(): + _add_reference_distributed_to_pretrained(reference_model, self.model.distributed) + super()._validate() if self.reference_models: # TODO: Add support. Assert.eq(self.model.distributed.pipeline_parallel, 1) # TODO: Check if these work. Assert.eq(self.model.distributed.tensor_parallel, 1) Assert.eq(self.model.distributed.sequence_data_parallel, 1) - - for reference_model in self.reference_models.values(): - _add_reference_distributed_to_pretrained(reference_model, self.model.distributed) - super()._validate() if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() From cff9892d44a9380a992f33692500ed7e08191824 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 14 Apr 2025 18:11:05 -0400 Subject: [PATCH 021/122] fixes --- fast_llm/engine/inference/huggingface.py | 4 ++- tests/test_checkpoint.py | 33 +++++++++++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 75aea9dd0..196310b4d 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -60,7 +60,9 @@ def from_pretrained( updates[("distributed", "training_dtype")] = torch_dtype # Create the model - fast_llm_model = cls.model_class.from_pretrained(pretrained_model_name_or_path, updates, mode=mode) + fast_llm_model = cls.runner_class.model_class.from_pretrained( + pretrained_model_name_or_path, updates, mode=mode + ) config = cls.config_class(fast_llm_model.config) return cls(config, fast_llm_model, **kwargs) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 4171581a0..0c5e177d2 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -263,6 +263,7 @@ def test_load_pretrained_distributed_checkpoint(): path=_CKPT_PATH, format=DistributedCheckpointFormat, optimizer_state=True, + load_config=ModelConfigType.model, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) _compare_model_configs(config, model.config) @@ -276,19 +277,25 @@ def test_load_pretrained_distributed_checkpoint(): @pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) def test_load_converted_distributed_checkpoint(): config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( - CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) + CheckpointLoadConfig( + path=_CKPT_PATH, + format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, + ) ) model = TEST_MODEL_CLS.from_pretrained( CheckpointLoadConfig( path=_CONVERT_PATH / "distributed_0", format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) ) config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( CheckpointLoadConfig( path=_CONVERT_PATH / "distributed_1", format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) ) _compare_architectures(config_ref, model.config) @@ -302,13 +309,25 @@ def test_load_converted_distributed_checkpoint(): @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_fast_llm_checkpoint(): config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( - CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) + CheckpointLoadConfig( + path=_CKPT_PATH, + format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, + ) ) model = TEST_MODEL_CLS.from_pretrained( - CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat) + CheckpointLoadConfig( + path=_CONVERT_PATH / "fast_llm_0", + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) ) config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( - CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat) + CheckpointLoadConfig( + path=_CONVERT_PATH / "fast_llm_1", + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) ) _compare_architectures(config_ref, model.config) _compare_architectures(config_ref, config_alt) @@ -324,12 +343,14 @@ def test_load_converted_huggingface_checkpoint(): CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) ) model = TEST_MODEL_CLS.from_pretrained( CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_1", format=HUGGINGFACE_CHECKPOINT_FORMAT, + load_config=ModelConfigType.model, ), mode=StageMode.weights, ) @@ -337,6 +358,7 @@ def test_load_converted_huggingface_checkpoint(): CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, + load_config=ModelConfigType.model, ) ) _compare_architectures(config_ref, model.config) @@ -353,6 +375,7 @@ def test_run_converted_model(): CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) ) test_input = torch.randint( @@ -364,6 +387,7 @@ def test_run_converted_model(): CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, + load_config=ModelConfigType.model, ) ) errors = [] @@ -479,6 +503,7 @@ def test_load_distributed_checkpoint_dp2(): pretrained_config_test = CheckpointLoadConfig( path=TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoint" / "1", format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_test, mode=StageMode.weights) From b67006a0ea650145357e023d6c4f517a4b7de2c2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 14 Apr 2025 20:33:57 -0400 Subject: [PATCH 022/122] fixes --- fast_llm/engine/distributed/config.py | 3 ++- fast_llm/engine/training/config.py | 7 +++---- fast_llm/functional/triton/cross_entropy.py | 2 +- tests/test_triton_kernels.py | 13 +++++++------ 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 8f04a7054..66d89e1a0 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -293,7 +293,6 @@ def _validate(self) -> None: if self.reference_config.reference_config is not None: self.reference_config = self.reference_config.reference_config assert self.reference_config.reference_config is None - self.compare(self.reference_config, ValueError) self.distributed_dims = self.reference_config.distributed_dims else: self.distributed_dims = {} @@ -368,6 +367,8 @@ def _validate(self) -> None: super()._validate() + if self.reference_config is not None: + self.compare(self.reference_config, ValueError) Assert.in_range(self.rank, 0, self.world_size) Assert.in_range(self.local_rank, 0, self.local_world_size) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 578937fb5..8b4cadc31 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -429,13 +429,12 @@ def _add_reference_distributed_to_pretrained(pretrained: PretrainedFastLLMModelC def new_setup(): # Make sure the distributed config isn't set - # TODO!!!!!!!!!!!!!: Uncomment after #205 - # pretrained.model.distributed.validate() - # Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) + pretrained.model.distributed.validate() + Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) with NoAutoValidate(): pretrained.model.distributed = distributed.to_copy() # Allow sharing the `Distributed` instance. pretrained.model.distributed.reference_config = distributed old_setup() - pretrained._setup = new_setup + object.__setattr__(pretrained, "_setup", new_setup) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 62ed2e0ee..d825af034 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -96,7 +96,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( target = exp_target_logits / sum_exp_target_logits # per_sample_loss = log(sum_exp_logits) - sum(probabilities * logits) - loss = tl.log(sum_exp_logits) - tl.sum(target * logits_norm, 0) + loss = tl.log(sum_exp_logits) - tl.sum(tl.where(mask, target * logits_norm, 0), 0) tl.store(losses_ptr + block_idx, loss) if grad_losses is not None: diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py index b6970ddfe..1ace81d76 100644 --- a/tests/test_triton_kernels.py +++ b/tests/test_triton_kernels.py @@ -194,12 +194,13 @@ def test_triton_mlp_activation(gated, activation_type, recompute): @pytest.mark.parametrize( ("num_columns", "grad_output", "logits_scale_factor"), ( - (8192, 1.0, 1.0), - (8192, None, 1.0), - (8192, 1.0, 4.0), - (8192, 4.0, 1.0), - (65536, 1.0, 1.0), - (131072, 1.0, 1.0), + (8192, 1.0, 1.0), # Simple + (5000, 1.0, 1.0), # Not a power of 2 + (5000, None, 1.0), # No grad + (5000, 1.0, 4.0), # Loss scaling + (5000, 4.0, 1.0), # Grad scaling + (65536, 1.0, 1.0), # Max block size + (65537, 1.0, 1.0), # Above max block size ), ) @pytest.mark.parametrize("target_format", (TargetFormat.labels, TargetFormat.logits, TargetFormat.probabilities)) From 20141081adc406e3315cce9825df5b58b9630258 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 16 Apr 2025 11:13:00 -0400 Subject: [PATCH 023/122] Add constraints --- fast_llm/models/gpt/config.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 09c3e757d..f0c314d61 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -144,9 +144,37 @@ def _validate(self) -> None: Assert.empty(self.reference_models) else: Assert.eq(self.reference_models.keys(), {name}) + if self.model.base_model.use_absolute_position_embeddings: + Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) for reference_model in self.reference_models.values(): - Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.none(reference_model.model.base_model.distillation_model) + # TODO: Support more LM head features. + Assert.none(reference_model.model.base_model.cross_entropy_splits) + Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) + Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) + # TODO: Support distinct preprocessing + reference_model.model.base_model.transformer.rotary.compare( + self.model.base_model.transformer.rotary, + NotImplementedError, + ) + Assert.eq( + reference_model.model.base_model.use_absolute_position_embeddings, + self.model.base_model.use_absolute_position_embeddings, + ) + if reference_model.model.base_model.use_absolute_position_embeddings: + assert self.model.base_model.use_absolute_position_embeddings + Assert.geq( + reference_model.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length + ) + use_flash = reference_model.model.base_model.transformer.do_use_flash_attention( + reference_model.model.distributed + ) + Assert.eq(use_flash, self.model.base_model.transformer.do_use_flash_attention(self.model.distributed)) + if use_flash: + Assert.eq( + reference_model.model.base_model.transformer.window_size, + self.model.base_model.transformer.window_size, + ) @classmethod def get_trainer_class(cls) -> type["GPTTrainer"]: From fa3d556f371a75c29da53e082474c47a557dfa29 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 16 Apr 2025 12:35:16 -0400 Subject: [PATCH 024/122] Add constraints --- fast_llm/models/gpt/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index b02dd7be2..705e9918a 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -186,6 +186,9 @@ def _validate(self) -> None: Assert.eq(self.reference_models.keys(), {name}) if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + if self.model.base_model.distillation_model is not None: + # TODO: Support loss masking for distillation? + assert not self.batch.use_loss_masking_spans for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) # TODO: Support more LM head features. From 6c2c887b47dd1220dc626ac6edb817e004e0173d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 16 Apr 2025 14:30:09 -0400 Subject: [PATCH 025/122] Separate reference model preprocessing --- fast_llm/engine/base_model/base_model.py | 27 +++++++++---- fast_llm/engine/training/trainer.py | 9 +---- fast_llm/models/gpt/config.py | 23 ----------- fast_llm/models/gpt/model.py | 50 ++++++++++++++++-------- fast_llm/models/gpt/trainer.py | 23 ----------- 5 files changed, 54 insertions(+), 78 deletions(-) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 76da0f9b0..3835b1909 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -6,13 +6,16 @@ import torch.nn from fast_llm.config import Configurable -from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig, Preprocessor +from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.engine.inference.runner import InferenceRunner + class Module(torch.nn.Module, abc.ABC): """ """ @@ -80,6 +83,7 @@ def get_layers(self) -> list[Layer]: class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], SequentialLayers, abc.ABC): config_class: typing.ClassVar[type[BaseModelConfig]] = BaseModelConfig + _is_setup: bool = False def __init__( self, @@ -96,6 +100,16 @@ def __init__( # Rename to the parameter full name value.tensor_name = key + # Reference models + # TODO: Add basic handling (preprocessor) in this class. + self._reference_models: dict[str, "InferenceRunner"] = {} + + def setup(self, distributed: Distributed) -> None: + assert not self._is_setup + distributed.check_config(self._tensor_space.distributed_config) + self._tensor_space.setup(distributed) + self._is_setup = True + @classmethod def architecture_cls(cls) -> type[BaseModelArchitectureConfig]: return cls.config_class.architecture_class @@ -104,10 +118,6 @@ def architecture_cls(cls) -> type[BaseModelArchitectureConfig]: def get_layers(self) -> list[Layer]: pass - @abc.abstractmethod - def setup(self, distributed: Distributed) -> None: - pass - @abc.abstractmethod def preprocess_meta(self, batch_meta: typing.Any, phase: PhaseType) -> list[tuple[TensorMeta, dict]]: pass @@ -136,6 +146,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: def loss_defs(self) -> list[LossDef]: pass - def add_preprocessor(self, preprocessor: Preprocessor): - # TODO: Generalize preprocessors. - raise NotImplementedError() + def add_reference_model(self, name: str, inference_runner: InferenceRunner) -> None: + assert name not in self._reference_models + assert not self._is_setup + self._reference_models[name] = inference_runner diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 66f1ad869..abd8f9dc0 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -12,11 +12,9 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.dataset.config import SamplingParameters -from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.engine.optimizer.optimizer import Optimizer @@ -55,9 +53,7 @@ def __init__(self, config: TrainerConfig): self._reference_models[name] = self._config.get_inference_runner_class()( reference_config.model.get_model_class()(reference_config.model) ) - self._multi_stage.base_model.add_preprocessor( - self._get_reference_model_preprocessor(name, self._reference_models[name]) - ) + self._multi_stage.base_model.add_reference_model(name, self._reference_models[name]) phase: PhaseType self._runner = ScheduleRunner( @@ -562,6 +558,3 @@ def _get_last_checkpoint(self) -> int | None: def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: # TODO: Do in model, automate/generalize, get other stats pass - - def _get_reference_model_preprocessor(self, name: str, inference_runner: InferenceRunner) -> Preprocessor: - raise NotImplementedError() diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 705e9918a..e6230116d 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -195,29 +195,6 @@ def _validate(self) -> None: Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) - # TODO: Support distinct preprocessing - reference_model.model.base_model.transformer.rotary.compare( - self.model.base_model.transformer.rotary, - NotImplementedError, - ) - Assert.eq( - reference_model.model.base_model.use_absolute_position_embeddings, - self.model.base_model.use_absolute_position_embeddings, - ) - if reference_model.model.base_model.use_absolute_position_embeddings: - assert self.model.base_model.use_absolute_position_embeddings - Assert.geq( - reference_model.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length - ) - use_flash = reference_model.model.base_model.transformer.do_use_flash_attention( - reference_model.model.distributed - ) - Assert.eq(use_flash, self.model.base_model.transformer.do_use_flash_attention(self.model.distributed)) - if use_flash: - Assert.eq( - reference_model.model.base_model.transformer.window_size, - self.model.base_model.transformer.window_size, - ) @classmethod def _from_dict( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 55b08f2ea..77faa8a3d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -8,7 +8,6 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType -from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames @@ -41,7 +40,6 @@ class GPTBaseModel[ConfigType: GPTBaseModelConfig](BaseModel[ConfigType]): """ config_class: typing.ClassVar[type[GPTBaseModelConfig]] = GPTBaseModelConfig - _is_setup: bool = False _rotary_embedding_frequencies: torch.Tensor _position_ids: torch.Tensor _mask: torch.Tensor @@ -59,6 +57,7 @@ def __init__( for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron(param, self._config.transformer) # Noqa + # `self._reference_models` is not populated at this point, so we pass a mutable dict. self._preprocessors: list[Preprocessor] = [] if self._config.use_absolute_position_embeddings: self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) @@ -113,12 +112,6 @@ def get_layers(self) -> list[Layer]: *self.get_output_layers(), ] - def setup(self, distributed: Distributed) -> None: - assert not self._is_setup - distributed.check_config(self._tensor_space.distributed_config) - self._tensor_space.setup(distributed) - self._is_setup = True - def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: @@ -186,12 +179,20 @@ def preprocess_meta( TransformerKwargs.sequence_q_dim: sequence_q_dim, } - preprocessed_meta = [] - for sequence_k_past in range( + sequence_k_pasts = range( sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, sequence_length, micro_sequence_length, - ): + ) + reference_preprocessed_metas = {} + for name, reference_model in self._reference_models.items(): + reference_preprocessed_metas[name] = reference_model.fast_llm_model.base_model.preprocess_meta( + batch_meta, phase + ) + Assert.eq(len(reference_preprocessed_metas[name]), len(sequence_k_pasts)) + + preprocessed_meta = [] + for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) @@ -209,6 +210,15 @@ def preprocess_meta( ) for preprocessor in self._preprocessors: preprocessor.preprocess_meta(kwargs) + reference_kwargs = {} + for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): + reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] + for key, value in common_kwargs.items(): + Assert.eq(reference_kwargs_[key], value) + Assert.eq(reference_kwargs_[TransformerKwargs.sequence_k_dim], sequence_k_dim) + reference_kwargs[name] = reference_kwargs_ + kwargs["reference_models"] = reference_kwargs + preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -237,13 +247,22 @@ def preprocess( dtype=torch.int64, non_blocking=True, ) + + reference_logits = {} + for name, reference_model in self._reference_models.items(): + reference_logits[name] = [] + for _, kwargs_meta in preprocessed_meta: + reference_tokens, reference_kwargs = kwargs_meta["reference_models"][name] + reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) + reference_logits[name].append(reference_kwargs["logits"]) + if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() preprocessed = [] presents = None - for i, (tokens_meta, kwargs_meta) in enumerate(preprocessed_meta): + for i, (_, kwargs_meta) in enumerate(preprocessed_meta): sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size if sequence_first: tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] @@ -286,6 +305,9 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels + for name, reference_logits_ in reference_logits.items(): + kwargs[f"{name}_logits"] = reference_logits_[i] + for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) @@ -361,10 +383,6 @@ def loss_defs(self) -> list[LossDef]: ) return loss_defs - def add_preprocessor(self, preprocessor: Preprocessor): - assert not self._is_setup - self._preprocessors.append(preprocessor) - class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index a269f5a63..57327b272 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -3,33 +3,13 @@ from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.gpt.config import GPTSamplingParameters -from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.models.gpt.config import GPTTrainerConfig -from fast_llm.models.gpt.model import GPTInferenceRunner logger = logging.getLogger(__name__) -class GPTReferenceModelPreprocessor(Preprocessor): - def __init__(self, name: str, inference_runner: GPTInferenceRunner): - self._name = name - self._inference_runner = inference_runner - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - pass - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - # TODO: Fix random state/iteration. - preprocess_kwargs = kwargs.copy() - del preprocess_kwargs[LanguageModelKwargs.labels] - self._inference_runner.forward(batch, preprocess_kwargs, iteration=1) - # TODO: Improve. - kwargs[f"{self._name}_logits"] = preprocess_kwargs["logits"] - - class GPTTrainer[ConfigType: GPTTrainerConfig](Trainer[ConfigType]): config_class: typing.ClassVar[type[GPTTrainerConfig]] = GPTTrainerConfig @@ -101,6 +81,3 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, hardware_flops = flops_per_iteration + 7 / 6 * attn_flops ratio = elapsed_time_per_iteration * self._config.model.distributed.world_size * 1e12 return model_tflops / ratio, hardware_flops / ratio - - def _get_reference_model_preprocessor(self, name: str, inference_runner: GPTInferenceRunner) -> Preprocessor: - return GPTReferenceModelPreprocessor(name, inference_runner) From 67f9db637242498329f932dcf0650b4829e6599e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 16 Apr 2025 16:10:19 -0400 Subject: [PATCH 026/122] fix --- fast_llm/engine/base_model/base_model.py | 2 +- fast_llm/engine/multi_stage/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 3835b1909..2dbf8cc81 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -146,7 +146,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: def loss_defs(self) -> list[LossDef]: pass - def add_reference_model(self, name: str, inference_runner: InferenceRunner) -> None: + def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None: assert name not in self._reference_models assert not self._is_setup self._reference_models[name] = inference_runner diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 69bf3695a..e2d04f80f 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -30,7 +30,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.inference.model import HuggingfacePreTrainedModel + from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel logger = logging.getLogger(__name__) From 537deca2a96353b149908aa14c0dbeb8b2188bc3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 17 Apr 2025 15:47:47 -0400 Subject: [PATCH 027/122] fix --- fast_llm/functional/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 0a6118328..1eb6c8c04 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -112,7 +112,7 @@ def fused_cross_entropy_forward_backward( # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. if target_format == TargetFormat.labels: grad_base = exp_logits.scatter_add( - 1, target, -sum_exp_logits if target_mask is None else -target_mask * sum_exp_logits + 1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) ) else: grad_base = exp_logits - sum_exp_logits * target From d2b3154f22c574ec4b36b4cc706a42964b55684e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 21 Apr 2025 10:49:24 -0400 Subject: [PATCH 028/122] misc --- fast_llm/layers/language_model/head.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 8edf1bc84..c2974415d 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -50,9 +50,7 @@ def __init__( self._group_size = tensor_space.distributed_config.tensor_parallel self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings - self._sequence_parallel_logits = ( - tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings - ) + self._sequence_parallel_logits = self._sequence_parallel and not self._parallel_embeddings self._cross_entropy_splits = config.cross_entropy_splits if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings @@ -215,12 +213,12 @@ def _logits_cross_entropy_forward_backward_split( grad_output /= self._cross_entropy_splits logit_input = input_.flatten(0, -2) logit_input_grad = torch.empty_like(logit_input) - for logit_input_, labels_, logit_input_grad_ in zip( + for logit_input_, target_, logit_input_grad_ in zip( logit_input.split(split_size), target.split(split_size), logit_input_grad.split(split_size) ): loss_, grad_ = self._logits_cross_entropy_forward_backward( logit_input_, - labels_, + target_, weight, grad_output, kwargs, From a0ba05161cd03443b2b247e6db98322e5285b8b4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 25 Apr 2025 15:50:50 -0400 Subject: [PATCH 029/122] fixes --- fast_llm/engine/config_utils/tensor_space.py | 3 ++ fast_llm/engine/training/config.py | 30 ++++++++-------- fast_llm/models/gpt/model.py | 36 +++++++++++++------- 3 files changed, 42 insertions(+), 27 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 0384fdacd..5020bc650 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -27,6 +27,9 @@ def __repr__(self) -> str: f")" ) + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + @property def name(self) -> str: return self._name diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 8b4cadc31..1e990e9c8 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -23,7 +23,6 @@ DistributedCheckpointFormat, ) from fast_llm.engine.config_utils.run import ExperimentConfig -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig @@ -386,7 +385,7 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): def _validate(self) -> None: self.training.export.setup(self.model) for reference_model in self.reference_models.values(): - _add_reference_distributed_to_pretrained(reference_model, self.model.distributed) + self._add_reference_distributed_to_pretrained(reference_model) super()._validate() if self.reference_models: # TODO: Add support. @@ -396,6 +395,8 @@ def _validate(self) -> None: Assert.eq(self.model.distributed.sequence_data_parallel, 1) if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() + for reference_model in self.reference_models.values(): + assert reference_model.model.distributed.reference_config is self.model.distributed def _setup(self): super()._setup() @@ -423,18 +424,17 @@ def runnable(): return runnable + def _add_reference_distributed_to_pretrained(self, pretrained: PretrainedFastLLMModelConfig): + old_setup = pretrained._setup -def _add_reference_distributed_to_pretrained(pretrained: PretrainedFastLLMModelConfig, distributed: DistributedConfig): - old_setup = pretrained._setup - - def new_setup(): - # Make sure the distributed config isn't set - pretrained.model.distributed.validate() - Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) - with NoAutoValidate(): - pretrained.model.distributed = distributed.to_copy() - # Allow sharing the `Distributed` instance. - pretrained.model.distributed.reference_config = distributed - old_setup() + def new_setup(): + # Make sure the distributed config isn't set + pretrained.model.distributed.validate() + Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) + with NoAutoValidate(): + pretrained.model.distributed = self.model.distributed.to_copy() + # Allow sharing the `Distributed` instance. + pretrained.model.distributed.reference_config = self.model.distributed + old_setup() - object.__setattr__(pretrained, "_setup", new_setup) + object.__setattr__(pretrained, "_setup", new_setup) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c955eec5c..5c408a600 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -125,7 +125,7 @@ def preprocess_meta( else: micro_batch_size, sequence_length = batch_meta.shape if phase != PhaseType.inference: - sequence_length -= 1 + sequence_length -= self._config.prediction_heads micro_sequence_length = sequence_length batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) @@ -187,7 +187,7 @@ def preprocess_meta( reference_preprocessed_metas = {} for name, reference_model in self._reference_models.items(): reference_preprocessed_metas[name] = reference_model.fast_llm_model.base_model.preprocess_meta( - batch_meta, phase + batch_meta, PhaseType.inference ) Assert.eq(len(reference_preprocessed_metas[name]), len(sequence_k_pasts)) @@ -213,9 +213,14 @@ def preprocess_meta( reference_kwargs = {} for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] - for key, value in common_kwargs.items(): - Assert.eq(reference_kwargs_[key], value) - Assert.eq(reference_kwargs_[TransformerKwargs.sequence_k_dim], sequence_k_dim) + for key in ( + TransformerKwargs.sequence_first, + TransformerKwargs.hidden_dims, + TransformerKwargs.sequence_length, + TransformerKwargs.sequence_q_dim, + TransformerKwargs.sequence_k_dim, + ): + Assert.eq(reference_kwargs_[key], kwargs[key]) reference_kwargs[name] = reference_kwargs_ kwargs["reference_models"] = reference_kwargs @@ -249,13 +254,21 @@ def preprocess( non_blocking=True, ) - reference_logits = {} + reference_logits = [{} for _ in preprocessed_meta] for name, reference_model in self._reference_models.items(): - reference_logits[name] = [] - for _, kwargs_meta in preprocessed_meta: - reference_tokens, reference_kwargs = kwargs_meta["reference_models"][name] + reference_preprocessed_meta = [ + (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta + ] + + reference_batch = reference_model.fast_llm_model.base_model.preprocess( + batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration + ) + + # TODO: Do things work with >1? + Assert.eq(len(reference_batch), len(preprocessed_meta), 1) + for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch): reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) - reference_logits[name].append(reference_kwargs["logits"]) + reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. @@ -308,8 +321,7 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels - for name, reference_logits_ in reference_logits.items(): - kwargs[f"{name}_logits"] = reference_logits_[i] + kwargs.update(reference_logits[i]) for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) From 9ddfb69115d6d84d960e76c36d4e65994eae76cc Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 25 Apr 2025 21:43:23 +0000 Subject: [PATCH 030/122] add per-layer lr-scale --- fast_llm/layers/common/config.py | 3 ++- fast_llm/layers/common/normalization.py | 5 +++++ fast_llm/layers/language_model/config.py | 13 +++++++++++++ fast_llm/layers/language_model/embedding.py | 2 ++ fast_llm/layers/language_model/head.py | 1 + fast_llm/layers/transformer/attention.py | 13 ++++++++----- fast_llm/layers/transformer/config.py | 6 ++++++ .../layers/transformer/mixture_of_experts.py | 9 ++++++--- fast_llm/layers/transformer/mlp.py | 15 ++++++++++----- fast_llm/layers/transformer/transformer.py | 5 +++-- fast_llm/models/gpt/model.py | 2 +- fast_llm/utils.py | 18 ++++++++++++++++++ 12 files changed, 75 insertions(+), 17 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 71c15c9b8..6e596751d 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -82,7 +82,7 @@ class NormalizationConfig(NormalizationArchitectureConfig, BaseModelConfig): valid=check_field(Assert.geq, 0), ) - def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": from fast_llm.layers.common.normalization import LayerNorm, RMSNorm from fast_llm.tensor import init_uniform_ @@ -91,6 +91,7 @@ def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": "eps": self.epsilon, "implementation": self.implementation, "zero_centered": self.zero_centered, + "lr_scale": lr_scale, } if self.initialization_range: mean = 0 if self.zero_centered else 1 diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 04123014e..984778f83 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -152,6 +152,7 @@ def __init__( weight_init_method=None, bias_init_method=init_zeros_, zero_centered: bool = False, + lr_scale: float | None = None, ): super().__init__() assert hidden_dim.parallel_dim is None @@ -190,12 +191,14 @@ def __init__( init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, + lr_scale=lr_scale, ) self.bias = ParameterMeta.from_dims( (hidden_dim,), init_method=bias_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, + lr_scale=lr_scale, ) self.normalized_shape = self.weight.shape @@ -230,6 +233,7 @@ def __init__( implementation: NormalizationImplementation = NormalizationImplementation.auto, weight_init_method=None, zero_centered: bool = False, + lr_scale: float | None = None, ): super().__init__() assert hidden_dim.parallel_dim is None @@ -263,6 +267,7 @@ def __init__( init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=True, + lr_scale=lr_scale, ) self.normalized_shape = self.weight.shape diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index b4b4e187c..c99ee4f6a 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -202,6 +202,19 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + embeddings_lr_scale: float | None = Field( + default=None, + desc="Learning rate scale for the word embeddings.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + output_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the output weights.", + doc="May be used to freeze the output weights by setting their scale to zero.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.transformer.validate() diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d9406ed1..e0386d8df 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -62,6 +62,7 @@ def __init__( min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.embeddings_lr_scale, ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( @@ -72,6 +73,7 @@ def __init__( max_val=config.init_method_max_embed, ), allow_sequence_tensor_parallel=not config.parallel_embeddings, + lr_scale=config.embeddings_lr_scale, ) # PEFT. diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c2974415d..1153fb2c2 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -102,6 +102,7 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.output_lr_scale, ) def forward( diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index c7ae55c5c..54fff2286 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -17,7 +17,7 @@ ) from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -84,7 +84,7 @@ def __init__( super().__init__() self._config = config self._tensor_space = tensor_space - Assert.in_range_incl(layer_index, 1, self._config.num_layers) + # Assert.in_range_incl(layer_index, 1, self._config.num_layers) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer @@ -110,6 +110,9 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) + # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, @@ -118,7 +121,7 @@ def __init__( weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, @@ -127,7 +130,7 @@ def __init__( weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) @@ -139,7 +142,7 @@ def __init__( weight_init_method=init_method_std_attn_proj, bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) # PEFT. diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index cf409e773..c13c2a093 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -636,6 +636,12 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): doc="May be used to freeze some experts by setting their scale to zero.", hint=FieldHint.feature, ) + per_layer_lr_scale: list[float] | None = Field( + default=None, + desc="Custom learning rate scale for each layer.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + ) router_lr_scale: float | None = Field( default=None, desc="Custom learning rate for the MoE router weight.", diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 85c6686f4..49778c63f 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -21,7 +21,7 @@ from fast_llm.layers.transformer.mlp import MLPBase from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta, init_normal_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." @@ -59,6 +59,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) + self.router = Linear( tensor_space.get_tensor_dim(TransformerDimNames.hidden), tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), @@ -66,7 +69,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max ), - lr_scale=config.router_lr_scale, + lr_scale=router_lr_scale, ) dropless_moe = config.dropless_moe if dropless_moe and tensor_space.distributed_config.sequence_tensor_parallel: diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 1c38705f9..c4d8afdc7 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -10,13 +10,14 @@ from fast_llm.layers.common.linear import LinearBase from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): super().__init__() self._name = name + self._layer_index = layer_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -38,6 +39,10 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale + lr_scale = get_lr_scale(lr_scale, layer_lr_scale) + # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, @@ -45,7 +50,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, - lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, + lr_scale=lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, @@ -55,7 +60,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, - lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, + lr_scale=lr_scale, ) # PEFT. @@ -64,7 +69,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): Assert.eq(config.num_experts, 1) super().__init__(config, tensor_space, name) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 92df18937..9e1e0bcfa 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -39,8 +39,9 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale=layer_lr_scale) + self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale=layer_lr_scale) self._create_mixer() diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a7ec58d67..873c8f80e 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -80,7 +80,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=self._config.transformer.num_layers, + layer_index=self._config.transformer.num_layers + i, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, diff --git a/fast_llm/utils.py b/fast_llm/utils.py index a8c5eac61..c524a315d 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -326,3 +326,21 @@ def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple def check_equal_nested(config_a, config_b): if errors := compare_nested(config_a, config_b): raise ValueError("\n".join(errors)) + + +def get_lr_scale( + lr_scale: float | None | tuple[float | None, ...], layer_lr_scale: float | None +) -> float | None | tuple[float | None, ...]: + """ + Combine module and layer lr_scale. + If one is None, return the other. + """ + if lr_scale is None: + return layer_lr_scale + if layer_lr_scale is None: + return lr_scale + if isinstance(lr_scale, float): + return lr_scale * layer_lr_scale + if isinstance(lr_scale, tuple): + return tuple(lrs * layer_lr_scale if lrs is not None else layer_lr_scale for lrs in lr_scale) + raise ValueError(f"Invalid lr_scale: {lr_scale} (type {type(lr_scale)})") From 5e282cc0369d40ecbd81963fecadd7c699317b2a Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 28 Apr 2025 16:24:17 +0000 Subject: [PATCH 031/122] modeling mtp llamba --- .../ssm/external/configuration_mtp_llamba.py | 94 +++++ .../models/ssm/external/discrete_mamba2.py | 382 +++++++++++++++++ .../ssm/external/modeling_mtp_llamba.py | 389 ++++++++++++++++++ 3 files changed, 865 insertions(+) create mode 100644 fast_llm/models/ssm/external/configuration_mtp_llamba.py create mode 100644 fast_llm/models/ssm/external/discrete_mamba2.py create mode 100644 fast_llm/models/ssm/external/modeling_mtp_llamba.py diff --git a/fast_llm/models/ssm/external/configuration_mtp_llamba.py b/fast_llm/models/ssm/external/configuration_mtp_llamba.py new file mode 100644 index 000000000..b8173b733 --- /dev/null +++ b/fast_llm/models/ssm/external/configuration_mtp_llamba.py @@ -0,0 +1,94 @@ +from enum import Enum + +from transformers.configuration_utils import PretrainedConfig + + +class StateUpdateKernel(Enum): + ssu_verification = "ssu_verification" # selective scan for multi-token verification, not implemented yet + cs = "chunk_scan" # see https://proceedings.mlr.press/v262/wu24a.html + ssu = "standard" # usual one token per time-step inference using selective-scan update, no verification + + +class MTPLlambaConfig(PretrainedConfig): + r"""Configuration class for the CustomMamba model. + + This configuration is used to instantiate the CustomMamba model according to the specified arguments, + defining the model architecture. + + Args: + vocab_size (`int`, *optional*, defaults to 128256): + Vocabulary size of the model. + tie_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + pad_vocab_size_multiple (`int`, *optional*, defaults to 8): + Pad the vocabulary size up to the next multiple of this value. + lm_head_bias (`bool`, *optional*, defaults to `False`): + Whether the LM head includes a bias term. + d_model (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + lm_head_prenorm (`str`, *optional*, defaults to "rms"): + Normalization type for LM head. + n_layer (`int`, *optional*, defaults to 32): + Number of layers in the model. + resid_dropout (`float`, *optional*, defaults to 0.0): + Dropout rate for residual connections. + norm_epsilon (`float`, *optional*, defaults to 1e-5): + Epsilon value used for normalization layers. + mlp_cfg (`dict`, *optional*): + Configuration for the MLP (Multi-Layer Perceptron) layer, including intermediate size, activation function, and whether to use bias. + ssm_cfg (`dict`, *optional*): + Configuration for the SSM (State Space Model) layer, including d_state, number of heads, expansion, and other parameters. + + """ + + model_type = "llamba" + + def __init__( + self, + vocab_size: int, + d_model: int, + tie_embeddings: bool = False, + pad_vocab_size_multiple: int = 8, + lm_head_bias: bool = False, + n_layer: int = 32, + resid_dropout: float = 0.0, + norm_epsilon: float = 1e-5, + mlp_cfg: dict = None, + ssm_cfg: dict = None, + prediction_heads=1, + state_update_kernel: StateUpdateKernel = StateUpdateKernel.cs, + **kwargs, + ): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.tie_embeddings = tie_embeddings + self.pad_vocab_size_multiple = pad_vocab_size_multiple + self.lm_head_bias = lm_head_bias + self.d_model = d_model + self.n_layer = n_layer + self.resid_dropout = resid_dropout + self.norm_epsilon = norm_epsilon + self.prediction_heads = prediction_heads + assert ( + state_update_kernel != StateUpdateKernel.ssu_verification + ), "Only chunk scan and standard modes are supported for now" + self.state_update_kernel = state_update_kernel + + # MLP (Multi-Layer Perceptron) Config + self.mlp_cfg = mlp_cfg or { + "intermediate_size": 14336, + "bias": False, + "act_fn": "silu", + } + + # SSM (State Space Model) Config + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 32, + "n_qk_heads": 32, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + } diff --git a/fast_llm/models/ssm/external/discrete_mamba2.py b/fast_llm/models/ssm/external/discrete_mamba2.py new file mode 100644 index 000000000..bb8afaa7d --- /dev/null +++ b/fast_llm/models/ssm/external/discrete_mamba2.py @@ -0,0 +1,382 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined + +from .configuration_mtp_llamba import StateUpdateKernel + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn, causal_conv1d_update = None, None + + +class DiscreteMamba2(nn.Module): + """DiscreteMamba2 (taken github.com/goombalab/phi-mamba.git).""" + + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + verification_mode: StateUpdateKernel = StateUpdateKernel.cs, + **kwargs, + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr". + + Other options are all experimental and should not need to be configured. + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + self.inference_mode = verification_mode + assert verification_mode in [ + StateUpdateKernel.cs, + StateUpdateKernel.standard, + ], "Only chunk scan and standard selective scan are supported for now" + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, **factory_kwargs)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, **factory_kwargs)) + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + """Returns the output dimension of the model.""" + return self.d_model + + @property + def state_to_tensor(self): + """Returns the state of the model as a tensor.""" + return self.layer.state_to_tensor + + def forward(self, u, inference_params=None, **kwargs): + """ + Args: + u: (B, L, D), + inference_params: dict.. Here we assume it contains a mask tensor of shape (B, L) with 1s for valid tokens and 0s for no-op tokens. + + Returns: + outputs: dict. + outputs["hidden_states"]: (B, L, D). + outputs["state"]: inference cache. + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + state = None + if inference_params is not None: + state = self._get_states_from_cache(inference_params, batch) + + if ( + state is not None + and inference_params.seqlen_offset > 0 # meaning we are in the middle of the sequence + and seqlen == 1 + and self.inference_mode != StateUpdateKernel.cs + ): + # we go in here for standard 1 token per time-step inference. + # seqlen_offset > 0 means we are in the middle of a sequence + # States are updated inplace + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _ = self.step(u, state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward( + xBC, padded_len, mask=inference_params.mask if inference_params is not None else None + ) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + # TODO: this kernel needs to be aupdated to use the mask! If used solely for throughout benchmarking, it is enough to call it as is. + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(state is not None), + ) + + if state is not None: + y, ssm_state = result + state["ssm"].copy_(ssm_state) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + return outputs + + def step(self, u, state, **kwargs): + """ + Args: + u: (B, D), + state: dict. + + Returns: + out: (B, D), + state: dict. + + """ + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state = self.convolutional_step(xBC, state["conv"]) + state["conv"].copy_(conv_state) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + state["ssm"] = state["ssm"].to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=state["ssm"], # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate memory for inference cache.""" + device = self.in_proj.weight.device + # conv_state: + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, + self.d_conv, + self.conv1d.weight.shape[0], + device=device, + dtype=conv_dtype, + ).transpose(1, 2) + # ssm_state: + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + ssm_state = torch.zeros( + batch_size, + self.n_v_heads, + self.headdim, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + Get states from cache. + + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + if self.layer_idx not in inference_params.key_value_memory_dict: + inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + batch_size, inference_params.max_seqlen, dtype=torch.float32 + ) + # Get states + states = inference_params.key_value_memory_dict[self.layer_idx] + if initialize_states: + states["conv"].zero_() + states["ssm"].zero_() + return states + + def convolutional_forward(self, xBC, padded_len, mask=None): + """Convolutional layer forward pass for the full sequence.""" + seqlen = xBC.shape[1] + mask_seql = -1 if mask is None else mask.shape[1] + # If seqlen != mask_seql, this likely means we preallocated mask for static generation, + # but here we are in the prefill phase. + # Note, mask is needed to prevent state upodate for no-op tokens as described in https://proceedings.mlr.press/v262/wu24a.html + # Note, if we want to use joint attanimnet and advancement in selective-scan mode, we would need to implement masking into the kernel of causal_conv1d_fn and mamba_chunk_scan_combined + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + if mask_seql == seqlen: + xBC = xBC * mask.unsqueeze(-1) + + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + if mask_seql == seqlen: + xBC = xBC * mask.unsqueeze(-1) + else: + # TODO: note, this only works for chunked inference, for autoregressive mode we need to update the kernel to make sure conv state is not poluted + if mask_seql == seqlen: + xBC = xBC * mask.unsqueeze(-1) + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + + if mask_seql == seqlen: + xBC = xBC * mask.unsqueeze(-1) + return xBC + + def convolutional_step(self, xBC, conv_state): + """Convolutional layer forward pass for a single step.""" + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state diff --git a/fast_llm/models/ssm/external/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/modeling_mtp_llamba.py new file mode 100644 index 000000000..6d9746db1 --- /dev/null +++ b/fast_llm/models/ssm/external/modeling_mtp_llamba.py @@ -0,0 +1,389 @@ +# Copyright (c) 2024, Kevin Li, Aviv Bick. + +import json +import os +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin +from mamba_ssm.utils.generation import GenerationMixin +from torch import Tensor, nn +from transformers.activations import ACT2FN +from transformers.utils.generic import ModelOutput + +from .configuration_mtp_llamba import MTPLlambaConfig as LlambaConfig +from .discrete_mamba2 import DiscreteMamba2 + + +class LlamaRMSNorm(nn.Module): + """LlamaRMSNorm (taken from transformers.models.llama.modeling_llama.LlamaRMSNorm).""" + + def __init__(self, hidden_size, eps=1e-6, factory_kwargs=None): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + """ + Args: + hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size). + + Returns: + torch.Tensor of shape (batch_size, seq_len, hidden_size). + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + """Set the extra representation of the module.""" + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class LlamaMLP(nn.Module): + """LlamaMLP (taken from transformers.models.llama.modeling_llama.LlamaMLP).""" + + def __init__(self, hidden_size, intermediate_size, bias, act_fn, factory_kwargs=None): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias, **factory_kwargs) + self.act_fn = ACT2FN[act_fn] + + def forward(self, x): + """ + Args: + x: torch.Tensor of shape (batch_size, seq_len, hidden_size). + + Returns: + torch.Tensor of shape (batch_size, seq_len, hidden_size). + """ + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +@dataclass +class CustomMambaCausalLMOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + + +class MTPLlambaLMHeadModel(nn.Module, GenerationMixin, PyTorchModelHubMixin): + """MambaLM model with a language modeling head on top (linear layer).""" + + def __init__(self, config, initializer_cfg=None, device=None, dtype=None, **kwargs) -> None: + super().__init__() + + # Load config + if not isinstance(config, LlambaConfig): + config = LlambaConfig(**config) + self.config = config + + # Factory kwargs + factory_kwargs = {"device": device, "dtype": dtype} + + # Pad vocab size to be a multiple of pad_vocab_size_multiple + vocab_size = config.vocab_size + pad_vocab_size_multiple = config.pad_vocab_size_multiple + if vocab_size % pad_vocab_size_multiple != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + self.config.vocab_size = vocab_size + + # Mixer model + self.backbone = MixerModel( + input_size=vocab_size, + config=self.config, + initializer_cfg=initializer_cfg, + **factory_kwargs, + ) + + # MTP heads + self.mtp_heads = nn.ModuleList( + [ + Block( + config=config, + factory_kwargs=factory_kwargs, + layer_idx=layer_idx, + ).to(device) + for layer_idx in range(config.n_layer, config.n_layer + config.prediction_heads - 1) + ] + ) + + self.mtp_norms = nn.ModuleList( + [ + LlamaRMSNorm(config.d_model, eps=config.norm_epsilon, factory_kwargs=factory_kwargs) + for _ in range(config.prediction_heads - 1) + ] + ) + # LM head + if not self.config.tie_embeddings: + self.lm_head = nn.Linear( + in_features=self.config.d_model, + out_features=self.config.vocab_size, + bias=self.config.lm_head_bias, + **factory_kwargs, + ) + else: + self.lm_head = lambda x: x @ self.backbone.embedding.weight.t() + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + + mtps = { + i + self.config.n_layer: layer.allocate_inference_cache(*args, **kwargs) + for i, layer in enumerate(self.mtp_heads) + } + return {**self.backbone.allocate_inference_cache(*args, **kwargs), **mtps} + + def forward( + self, + input_ids, + position_ids=None, + return_hidden_states=False, + return_logits=True, + inference_params=None, + num_last_tokens=0, + ): + """ + Args: + input_ids: torch.Tensor of shape (batch_size, seq_len), + position_ids: torch.Tensor of shape (batch_size, seq_len), optional, not used (just for compatibility), + return_hidden_states: bool, optional, + return_logits: bool, optional, whether to compute the logits with the LM head, + inference_params: dict, optional, the model's inference cache, + num_last_tokens: int, optional. If > 0, only return the logits for the last n tokens. + + Returns: + CustomMambaCausalLMOutput. + + """ + outputs = self.backbone( + input_ids, + return_hidden_states=return_hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + + # MTP heads processing + latents = [] + hidden_states = outputs["last_hidden_state"] + hidden_states_before_last = outputs["hidden_state_before_last"] + + # last layer already has layer norm applied + latents.append(hidden_states) + + # Process through MTP heads + for i, mtp_head in enumerate(self.mtp_heads): + mtp_outputs = mtp_head( + hidden_states_before_last, + inference_params=inference_params, + position_ids=position_ids, + ) + mtp_hidden_states = mtp_outputs["hidden_states"] + latents.append(self.mtp_norms[i](mtp_hidden_states)) + + # Stack the latents to get (batch_size, seq_len, num_prediction_heads, hidden_size) + stacked_latents = torch.stack(latents, dim=-2) + + if return_logits: + if isinstance(self.lm_head, nn.Linear): + # Apply lm_head to each prediction head's output + logits = self.lm_head(stacked_latents).float() + else: + # Using the tied embedding weights + logits = self.lm_head(stacked_latents) + + outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] + else: + outputs["logits"] = None + + return CustomMambaCausalLMOutput( + loss=None, + logits=outputs["logits"], + all_hidden_states=outputs["all_hidden_states"], + last_hidden_state=stacked_latents, + ) + + def save_pretrained(self, save_directory): + """ + Minimal implementation of save_pretrained for MambaLMHeadModel. + Save the model and its configuration file to a directory. + """ + # Ensure save_directory exists + if not os.path.exists(save_directory): + os.makedirs(save_directory) + + # Save the model's state_dict + model_path = os.path.join(save_directory, "pytorch_model.bin") + torch.save(self.state_dict(), model_path) + + # Save the configuration of the model + config_path = os.path.join(save_directory, "config.json") + with open(config_path, "w") as f: + json.dump(self.config.to_dict(), f) + + +class MixerModel(nn.Module): + """Mixer model with a stack of Mixer layers.""" + + def __init__(self, input_size, config=None, device=None, dtype=None, **kwargs) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.config = config + self.embedding = nn.Embedding(input_size, self.config.d_model, **factory_kwargs) + + self.layers = nn.ModuleList( + [ + Block( + config=config, + factory_kwargs=factory_kwargs, + layer_idx=i, + ).to(device) + for i in range(self.config.n_layer) + ] + ) + + self.final_layernorm = LlamaRMSNorm( + hidden_size=self.config.d_model, + eps=self.config.norm_epsilon, + factory_kwargs=factory_kwargs, + ) + + return + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return {i: layer.allocate_inference_cache(*args, **kwargs) for i, layer in enumerate(self.layers)} + + def forward( + self, + input_ids, + return_hidden_states=False, + inference_params=None, + position_ids=None, + ): + """Run the model.""" + # Start running the layers + hidden_states = self.embedding(input_ids) + + # Initialize outputs + outputs = { + "last_hidden_state": None, + "hidden_state_before_last": None, + "all_hidden_states": (hidden_states,) if return_hidden_states else (), + } + + # Run the layers + for layer in self.layers: + layer_outputs = layer( + hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + if layer == self.layers[-1]: + outputs["hidden_state_before_last"] = hidden_states + # Record outputs + hidden_states = layer_outputs["hidden_states"] + if return_hidden_states: + outputs["all_hidden_states"] += (hidden_states,) + + # Last layer, apply layer norm + outputs["last_hidden_state"] = self.final_layernorm(hidden_states) + return outputs + + +class Block(nn.Module): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection. + + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + + def __init__(self, config, factory_kwargs, layer_idx, **kwargs): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + # Mixer + self.mixer = DiscreteMamba2( + d_model=self.config.d_model, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + # Other components + self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) + self.post_attention_layernorm = LlamaRMSNorm( + hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs + ) + self.mlp = LlamaMLP( + hidden_size=self.config.d_model, + **config.mlp_cfg, + factory_kwargs=factory_kwargs, + ) + + def forward( + self, + hidden_states: Tensor, + inference_params=None, + **kwargs, + ): + """ + Pass the input through the encoder layer. + + Args: + hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), + inference_params: dict, optional, + + Returns: + dict with keys: + hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), + mamba_hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), + transfer_matrix: torch.Tensor of shape (batch_size, seq_len, seq_len). + """ + outputs = {} + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Apply Mixer + mixer_outputs = self.mixer( + hidden_states, + inference_params=inference_params, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs["hidden_states"] = hidden_states + + return outputs + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for the model.""" + if getattr(self.mixer, "allocate_inference_cache", None) is None: + return + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) From 87b319769cbdf5a1130edc41c758b4c981fd5847 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 01:23:20 +0000 Subject: [PATCH 032/122] modeling apriel ssm --- .../ssm/external/configuration_ssm_apriel.py | 101 +++ .../ssm/external/modeling_ssm_apriel.py | 730 ++++++++++++++++++ 2 files changed, 831 insertions(+) create mode 100644 fast_llm/models/ssm/external/configuration_ssm_apriel.py create mode 100644 fast_llm/models/ssm/external/modeling_ssm_apriel.py diff --git a/fast_llm/models/ssm/external/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/configuration_ssm_apriel.py new file mode 100644 index 000000000..0c75ca658 --- /dev/null +++ b/fast_llm/models/ssm/external/configuration_ssm_apriel.py @@ -0,0 +1,101 @@ +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Apriel SSM model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torch_available, logging + +logger = logging.get_logger(__name__) + +if is_torch_available(): + pass + + +class AprielSSMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AprielModel`]. It is used to instantiate an Apriel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Apriel-5B-Base. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + .... + ```""" + + model_type = "apriel_ssm" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + hidden_act="silu", + initializer_range=0.02, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + mlp_bias=False, + rms_norm_eps=1e-5, + ssm_cfg: dict = None, + **kwargs, + ): + self.vocab_size = vocab_size + # self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + # self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + # self.rope_theta = rope_theta + self.mlp_bias = mlp_bias + # self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + # if self.rope_scaling is not None and "type" in self.rope_scaling: + # self.rope_scaling["rope_type"] = self.rope_scaling["type"] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 24, + "n_qk_heads": 24, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_inner": 4104, # to make sure we have 24 heads + } + + +__all__ = ["AprielConfig"] diff --git a/fast_llm/models/ssm/external/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/modeling_ssm_apriel.py new file mode 100644 index 000000000..d30d5b66c --- /dev/null +++ b/fast_llm/models/ssm/external/modeling_ssm_apriel.py @@ -0,0 +1,730 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from mamba_ssm.utils.generation import GenerationMixin +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.utils.generic import ModelOutput + +from .configuration_ssm_apriel import AprielSSMConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class CustomMambaCausalLMOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + + +class AprielRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + AprielRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(AprielRMSNorm) + + +class AprielMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs): + """ + u: (B, L, D) + Returns: same shape as u + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + state = None + if inference_params is not None: + state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # States are updated inplace + out, _ = self.step(u, state) + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(state is not None), + ) + + if state is not None: + y, ssm_state = result + state["ssm"].copy_(ssm_state) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + def step(self, u, state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u.squeeze(1)) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state = self.convolutional_step(xBC, state["conv"]) + state["conv"].copy_(conv_state) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + state["ssm"] = state["ssm"].to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=state["ssm"], # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.in_proj.weight.device + # conv_state: + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, + self.d_conv, + self.conv1d.weight.shape[0], + device=device, + dtype=conv_dtype, + ).transpose(1, 2) + # ssm_state: + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + ssm_state = torch.zeros( + batch_size, + self.n_v_heads, + self.headdim, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + if self.layer_idx not in inference_params.key_value_memory_dict: + inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + batch_size, inference_params.max_seqlen, dtype=torch.float32 + ) + # Get states + states = inference_params.key_value_memory_dict[self.layer_idx] + if initialize_states: + states["conv"].zero_() + states["ssm"].zero_() + return states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + ) + + self.mlp = AprielMLP(config) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, inference_params=None, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + inference_params=inference_params, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs["hidden_states"] = hidden_states + + return outputs + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for the model.""" + if getattr(self.mixer, "allocate_inference_cache", None) is None: + return + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +APRIEL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`AprielSSMConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMPreTrainedModel(PreTrainedModel): + config_class = AprielSSMConfig + base_model_prefix = "model" + _no_split_modules = ["AprielDecoderLayer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) + + +APRIEL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMModel(AprielSSMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`] + Args: + config: AprielSSMConfig + """ + + def __init__(self, config: AprielSSMConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [AprielDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return {i: layer.allocate_inference_cache(*args, **kwargs) for i, layer in enumerate(self.layers)} + + @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + return_hidden_states=False, + inference_params=None, + position_ids=None, + ) -> Union[tuple, BaseModelOutputWithPast]: + + hidden_states = self.embed_tokens(input_ids) + + # decoder layers + outputs = { + "last_hidden_state": None, + "all_hidden_states": (hidden_states,) if return_hidden_states else (), + } + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + + layer_outputs = decoder_layer( + hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + # Record outputs + hidden_states = layer_outputs["hidden_states"] + if return_hidden_states: + outputs["all_hidden_states"] += (hidden_states,) + + outputs["last_hidden_state"] = self.norm(hidden_states) + return outputs + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielSSMForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = AprielSSMModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids=None, + return_hidden_states=False, + return_logits=True, + inference_params=None, + num_last_tokens=0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + + outputs = self.model( + input_ids, + return_hidden_states=return_hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + + if outputs["last_hidden_state"] is not None and return_logits: + logits = self.lm_head(outputs["last_hidden_state"]).float() + outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] + else: + outputs["logits"] = None + + return CustomMambaCausalLMOutput( + loss=None, + logits=outputs["logits"], + all_hidden_states=outputs["all_hidden_states"], + last_hidden_state=outputs["last_hidden_state"], + ) + + +__all__ = [ + "AprielSSMForCausalLM", + "AprielModel", + "AprielSSMPreTrainedModel", +] From d3e1df246279e367cbd45bbf4b5492165b6e4af5 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 12:25:05 +0000 Subject: [PATCH 033/122] Apriel to SSM --- .../models/ssm/external/ariel_to_ssm.ipynb | 447 ++++++++++++++++++ 1 file changed, 447 insertions(+) create mode 100644 fast_llm/models/ssm/external/ariel_to_ssm.ipynb diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb new file mode 100644 index 000000000..8c5f64ae7 --- /dev/null +++ b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb @@ -0,0 +1,447 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import torch\n", + "from mamba_ssm import MambaLMHeadModel\n", + "from mamba_ssm.models.config_mamba import MambaConfig\n", + "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", + "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", + "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", + "from transformers.cache_utils import StaticCache\n", + "from types import SimpleNamespace\n", + "\n", + "# make sure the code changes reflected without reload\n", + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 9.90it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "AprielForCausalLM(\n", + " (model): AprielModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()\n", + "apriel_model.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.bfloat16" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_model.config.torch_dtype" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "n_params = sum(p.numel() for p in apriel_model.parameters() if p.requires_grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4.83207168" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n_params/1e9" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "apriel_ssm_config = AprielSSMConfig(vocab_size=config.vocab_size, \n", + " hidden_size=config.hidden_size,\n", + " intermediate_size=config.intermediate_size,\n", + " num_hidden_layers=config.num_hidden_layers,\n", + " hidden_act=config.hidden_act,\n", + " initializer_range=config.initializer_range,\n", + " use_cache=config.use_cache,\n", + " mlp_bias=config.mlp_bias,\n", + " tie_word_embeddings=config.tie_word_embeddings,\n", + " pad_token_id=config.pad_token_id,\n", + " bos_token_id=config.bos_token_id,\n", + " eos_token_id=config.eos_token_id,\n", + " rms_norm_eps=config.rms_norm_eps)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "apriel_ssm = AprielSSMForCausalLM(apriel_ssm_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMConfig {\n", + " \"_attn_implementation_autoset\": true,\n", + " \"bos_token_id\": 1,\n", + " \"eos_token_id\": 2,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_size\": 4096,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 8192,\n", + " \"mlp_bias\": false,\n", + " \"model_type\": \"apriel_ssm\",\n", + " \"num_hidden_layers\": 28,\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"ssm_cfg\": {\n", + " \"activation\": \"identity\",\n", + " \"bias\": false,\n", + " \"chunk_size\": 128,\n", + " \"d_inner\": 4104,\n", + " \"d_state\": 64,\n", + " \"expand\": 1,\n", + " \"n_qk_heads\": 24,\n", + " \"n_v_heads\": 24\n", + " },\n", + " \"tie_word_embeddings\": false,\n", + " \"transformers_version\": \"4.48.1\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 131072\n", + "}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm_config" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "N params SSM: 5.660780512\n" + ] + } + ], + "source": [ + "print(\"N params SSM:\", sum(p.numel() for p in apriel_ssm.parameters() if p.requires_grad)/1e9)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load State dict into SSM" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.weight', 'model.layers.7.mixer.conv1d.bias', 'model.layers.7.mixer.out_proj.weight', 'model.layers.8.mixer.z_bias', 'model.layers.8.mixer.D', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.out_proj.weight', 'model.layers.9.mixer.z_bias', 'model.layers.9.mixer.D', 'model.layers.9.mixer.in_proj.weight', 'model.layers.9.mixer.conv1d.weight', 'model.layers.9.mixer.conv1d.bias', 'model.layers.9.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.19.mixer.z_bias', 'model.layers.19.mixer.D', 'model.layers.19.mixer.in_proj.weight', 'model.layers.19.mixer.conv1d.weight', 'model.layers.19.mixer.conv1d.bias', 'model.layers.19.mixer.out_proj.weight', 'model.layers.20.mixer.z_bias', 'model.layers.20.mixer.D', 'model.layers.20.mixer.in_proj.weight', 'model.layers.20.mixer.conv1d.weight', 'model.layers.20.mixer.conv1d.bias', 'model.layers.20.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 'model.layers.21.mixer.out_proj.weight', 'model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight', 'model.layers.23.mixer.z_bias', 'model.layers.23.mixer.D', 'model.layers.23.mixer.in_proj.weight', 'model.layers.23.mixer.conv1d.weight', 'model.layers.23.mixer.conv1d.bias', 'model.layers.23.mixer.out_proj.weight', 'model.layers.24.mixer.z_bias', 'model.layers.24.mixer.D', 'model.layers.24.mixer.in_proj.weight', 'model.layers.24.mixer.conv1d.weight', 'model.layers.24.mixer.conv1d.bias', 'model.layers.24.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight', 'model.layers.26.mixer.z_bias', 'model.layers.26.mixer.D', 'model.layers.26.mixer.in_proj.weight', 'model.layers.26.mixer.conv1d.weight', 'model.layers.26.mixer.conv1d.bias', 'model.layers.26.mixer.out_proj.weight', 'model.layers.27.mixer.z_bias', 'model.layers.27.mixer.D', 'model.layers.27.mixer.in_proj.weight', 'model.layers.27.mixer.conv1d.weight', 'model.layers.27.mixer.conv1d.bias', 'model.layers.27.mixer.out_proj.weight'], unexpected_keys=['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight'])" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.load_state_dict(apriel_state_dict, strict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "apriel_ssm.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Save checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/ariel_ssm\")" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "24" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.model.layers[0].mixer.n_v_heads" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=12320, bias=False)\n", + " (conv1d): Conv1d(8192, 8192, kernel_size=(4,), stride=(1,), padding=(3,), groups=8192)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Try a forward pass" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", + "batch_size = 1\n", + "max_length = 128\n", + "state = SimpleNamespace()\n", + "state.key_value_memory_dict = apriel_ssm.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)\n", + "state.batch_size = batch_size\n", + "state.seqlen_offset = 0\n", + "static_inputs = {\"inference_params\": state,\n", + " \"input_ids\": input_ids,\n", + " \"use_cache\": True,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-5.4688, -1.6641, 0.4609, ..., -7.1562, -3.7812, -5.9062],\n", + " [-3.5000, 1.4297, 4.3125, ..., -5.3438, -4.9375, -2.9844],\n", + " [-3.1094, 0.7930, 2.2969, ..., -3.1250, -4.1875, -2.1250],\n", + " ...,\n", + " [-5.3438, -3.0938, -3.9062, ..., -4.9062, -3.0000, -3.9688],\n", + " [-3.0625, -3.2188, 5.6562, ..., -2.7812, -2.5938, -6.6562],\n", + " [-1.8438, -1.7500, 5.9062, ..., -3.7188, -2.1250, -0.8281]]],\n", + " device='cuda:0', grad_fn=), all_hidden_states=(), last_hidden_state=tensor([[[ 1.2266, 0.5547, -1.1953, ..., 0.1089, -2.5781, 0.6328],\n", + " [-0.4395, 0.5938, -0.1562, ..., -0.6719, -0.6367, -0.3086],\n", + " [ 0.0077, 0.6680, -1.0703, ..., -3.6875, 0.2207, 0.1299],\n", + " ...,\n", + " [-0.0703, 0.4551, 0.1104, ..., 1.3438, 1.3984, 1.1641],\n", + " [-0.0613, 1.9141, -0.5430, ..., -1.0312, -0.6680, 0.0518],\n", + " [-0.6172, 0.2148, -0.5977, ..., -1.2734, -0.1914, 2.2344]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=))" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.forward(**static_inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hymba2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 082cf22c941a2d8ea992d8f088f27fc57c92c4de Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 13:03:23 +0000 Subject: [PATCH 034/122] Apriel SSM conversion --- fast_llm/layers/ssm/config.py | 10 +- fast_llm/models/ssm/config.py | 24 +- fast_llm/models/ssm/conversion.py | 298 ++++++++++++++---- .../models/ssm/external/ariel_to_ssm.ipynb | 115 ++++++- 4 files changed, 374 insertions(+), 73 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 984858fcc..2effa8a6e 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -55,7 +55,7 @@ class SSMArchitectureConfig(BaseModelArchitectureConfig): hint=FieldHint.core, ) - dt_rank: int = Field( + dt_rank: None | int = Field( default=None, desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", hint=FieldHint.core, @@ -85,12 +85,16 @@ class SSMArchitectureConfig(BaseModelArchitectureConfig): hint=FieldHint.core, ) + d_inner: None | int = Field( + default=None, + desc="Inner dimension for Mamba2 blocks.", + hint=FieldHint.core, + ) + def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: self.activation_type = ActivationType.silu - if self.dt_rank is None: - self.dt_rank = -1 # set to -1, it will be overwrittem in ssm validation super()._validate() diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index b38467d32..9d8c9bfd0 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -49,12 +49,16 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: "Block pattern must contain at least one 'm' or 'm2', use gpt model for transformer only architectures" ) - if self.ssm.dt_rank < 0: + if self.ssm.dt_rank is None: mamba_dt_rank = math.ceil(self.transformer.hidden_size / 16) else: mamba_dt_rank = self.ssm.dt_rank - d_inner = int(self.ssm.expansion_factor * self.transformer.hidden_size) + d_inner = ( + int(self.ssm.expansion_factor * self.transformer.hidden_size) + if self.ssm.d_inner is None + else self.ssm.d_inner + ) # Hidden dimension tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) # Mamba-specific dimensions @@ -115,12 +119,26 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return LLambaHuggingfaceCheckpointHandler +class AprielSSMHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "apriel_ssm" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import AprielSSMHuggingfaceCheckpointHandler + + return AprielSSMHuggingfaceCheckpointHandler + + @config_class() class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" base_model: HybridSSMBaseModelConfig = FieldUpdate(default_factory=HybridSSMBaseModelConfig) - checkpoint_formats = FastLLMModelConfig.checkpoint_formats + (LLambaHuggingfaceCheckpointFormat,) + checkpoint_formats = FastLLMModelConfig.checkpoint_formats + ( + LLambaHuggingfaceCheckpointFormat, + AprielSSMHuggingfaceCheckpointFormat, + ) @classmethod def get_model_class(cls) -> type["HybridSSMModel"]: diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 190b2ffae..a8b6ceff3 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -5,6 +5,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + ConstantExportParamConverter, ConstantImportParamConverter, IgnoreImportWeightConverter, MappedConfigParamConverter, @@ -18,7 +19,11 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationType from fast_llm.models.gpt.conversion import MLPLayer2Converter -from fast_llm.models.ssm.config import HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import ( + AprielSSMHuggingfaceCheckpointFormat, + HybridSSMModelConfig, + LLambaHuggingfaceCheckpointFormat, +) from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.utils import Assert @@ -26,74 +31,17 @@ pass -class LLambaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): +class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - """ - Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json - """ return super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("n_layer",),), - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), - # TODO: is there an equivalen of pad_vocab_size_multiple in FastLLM, does it matter? - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) - ), - RenameParamConverter( - fast_llm_names=(("ssm", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm - ), RenameParamConverter( fast_llm_names=(("vocab_size",),), export_names=(("vocab_size",),), ), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_embeddings",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("d_model",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=( - ( - "mlp_cfg", - "intermediate_size", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "add_linear_biases"),), - export_names=( - ( - "mlp_cfg", - "bias", - ), - ), - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=( - ( - "mlp_cfg", - "act_fn", - ), - ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), RenameParamConverter( fast_llm_names=(("ssm", "state_size"),), export_names=( @@ -161,6 +109,238 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ] + +class LLambaHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + """ + Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json + """ + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) + ), + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("n_layer",),), + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=( + ( + "mlp_cfg", + "act_fn", + ), + ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "add_linear_biases"),), + export_names=( + ( + "mlp_cfg", + "bias", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=( + ( + "mlp_cfg", + "intermediate_size", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("d_model",),), + ), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_embeddings",),), + ), + ] + + def _create_weight_converters(self) -> list[WeightConverter]: + converters = [] + num_layers = self._model.config.base_model.transformer.num_layers + norm_bias: bool = False + ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear + + # Embedding and output + if self._model.config.base_model.tie_word_embeddings: + converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) + converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + else: + converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) + converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + + # Final norm + converters += self._get_weight_and_bias_converters( + f"layers.{num_layers + 1}.final_norm", "backbone.final_layernorm", norm_bias + ) + + for i in range(num_layers): + # SSM + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.in_proj", f"backbone.layers.{i}.mixer.in_proj", ssm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.out_proj", f"backbone.layers.{i}.mixer.out_proj", ssm_bias + ) + converters.append( + WeightConverter(f"layers.{i+1}.mixer.D", f"backbone.layers.{i}.mixer.D", self._model.config.base_model) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.z_bias", f"backbone.layers.{i}.mixer.z_bias", self._model.config.base_model + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_weight", + f"backbone.layers.{i}.mixer.conv1d.weight", + self._model.config.base_model, + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_bias", + f"backbone.layers.{i}.mixer.conv1d.bias", + self._model.config.base_model, + ) + ) + + # Norm + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.norm_1", f"backbone.layers.{i}.input_layernorm", norm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.norm_2", f"backbone.layers.{i}.post_attention_layernorm", norm_bias + ) + + # MLP + converters += self._get_mlp_converters(f"layers.{i+1}", f"backbone.layers.{i}") + + return converters + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), + ] + + def _get_weight_and_bias_converters( + self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + ) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters + + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + if not os.path.exists(directory / "config.json"): + raise FileNotFoundError(f"config.json not found in {directory}") + with open(directory / "config.json") as f: + config = json.load(f) + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + with open(directory / "config.json", "w") as f: + json.dump(config, f) + + +class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("ssm", "d_inner"),), + export_names=(("ssm_cfg", "d_inner"),), + ), + ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=(("intermediate_size",),), + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + RenameParamConverter( + fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_word_embeddings",),), + ), + ConstantImportParamConverter(fast_llm_names=(("hybrid_block_layout"),), fast_llm_value=["m2"]), + ] + def _create_weight_converters(self) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb index 8c5f64ae7..85608075a 100644 --- a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb +++ b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb @@ -48,7 +48,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 9.90it/s]\n" + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 6.68it/s]\n" ] }, { @@ -246,7 +246,52 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=11304, bias=False)\n", + " (conv1d): Conv1d(7176, 7176, kernel_size=(4,), stride=(1,), padding=(3,), groups=7176)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=4104, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "apriel_ssm.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -255,7 +300,7 @@ "_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.weight', 'model.layers.7.mixer.conv1d.bias', 'model.layers.7.mixer.out_proj.weight', 'model.layers.8.mixer.z_bias', 'model.layers.8.mixer.D', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.out_proj.weight', 'model.layers.9.mixer.z_bias', 'model.layers.9.mixer.D', 'model.layers.9.mixer.in_proj.weight', 'model.layers.9.mixer.conv1d.weight', 'model.layers.9.mixer.conv1d.bias', 'model.layers.9.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.19.mixer.z_bias', 'model.layers.19.mixer.D', 'model.layers.19.mixer.in_proj.weight', 'model.layers.19.mixer.conv1d.weight', 'model.layers.19.mixer.conv1d.bias', 'model.layers.19.mixer.out_proj.weight', 'model.layers.20.mixer.z_bias', 'model.layers.20.mixer.D', 'model.layers.20.mixer.in_proj.weight', 'model.layers.20.mixer.conv1d.weight', 'model.layers.20.mixer.conv1d.bias', 'model.layers.20.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 'model.layers.21.mixer.out_proj.weight', 'model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight', 'model.layers.23.mixer.z_bias', 'model.layers.23.mixer.D', 'model.layers.23.mixer.in_proj.weight', 'model.layers.23.mixer.conv1d.weight', 'model.layers.23.mixer.conv1d.bias', 'model.layers.23.mixer.out_proj.weight', 'model.layers.24.mixer.z_bias', 'model.layers.24.mixer.D', 'model.layers.24.mixer.in_proj.weight', 'model.layers.24.mixer.conv1d.weight', 'model.layers.24.mixer.conv1d.bias', 'model.layers.24.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight', 'model.layers.26.mixer.z_bias', 'model.layers.26.mixer.D', 'model.layers.26.mixer.in_proj.weight', 'model.layers.26.mixer.conv1d.weight', 'model.layers.26.mixer.conv1d.bias', 'model.layers.26.mixer.out_proj.weight', 'model.layers.27.mixer.z_bias', 'model.layers.27.mixer.D', 'model.layers.27.mixer.in_proj.weight', 'model.layers.27.mixer.conv1d.weight', 'model.layers.27.mixer.conv1d.bias', 'model.layers.27.mixer.out_proj.weight'], unexpected_keys=['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight'])" ] }, - "execution_count": 49, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -266,14 +311,58 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=11304, bias=False)\n", + " (conv1d): Conv1d(7176, 7176, kernel_size=(4,), stride=(1,), padding=(3,), groups=7176)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=4104, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "\n", "apriel_ssm.to(device).to(dtype=torch.bfloat16)" ] }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# apriel_ssm.state_dict()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -283,11 +372,21 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:2714: UserWarning: `save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead.\n", + " warnings.warn(\n" + ] + } + ], "source": [ - "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/ariel_ssm\")" + "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm\",\n", + " save_config=True)\n" ] }, { From 0d4d5c5b0cae2503a11b905e832d0693fa643c89 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Apr 2025 13:14:00 -0400 Subject: [PATCH 035/122] fix --- fast_llm/layers/language_model/config.py | 2 ++ fast_llm/models/gpt/model.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index b4b4e187c..4fb471fb3 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -215,6 +215,8 @@ def _validate(self) -> None: super()._validate() if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, self.init_method_max_embed) + if self.prediction_heads > 1: + Assert.gt(self.transformer.num_layers, 1) if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 54c3d8829..9e28373b3 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -102,6 +102,9 @@ def get_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, layer_index=i + 1, + # The last layer only returns the transformer output. + # The previous layers return a stack of shared_hidden and transformer_output. + return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, ) for i in range(self._config.transformer.num_layers) ], From c43e535ce8a296ed95e468b370e45d6165d5beb3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 18:52:28 +0000 Subject: [PATCH 036/122] wip --- fast_llm/models/ssm/config.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 9d8c9bfd0..7cad0d529 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -10,7 +10,7 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.language_model.config import LanguageModelArchitectureConfig, LanguageModelBaseConfig from fast_llm.layers.ssm.config import SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -169,9 +169,36 @@ class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + reference_models: dict[str, PretrainedGPTModelConfig] = ( + FieldUpdate() + ) # TODO: make sure any reference mdoel can be suported @classmethod def get_trainer_class(cls) -> type["SSMTrainer"]: from fast_llm.models.ssm.trainer import SSMTrainer return SSMTrainer + + def _validate(self) -> None: + super()._validate() + if (name := self.model.base_model.distillation_model) is None: + Assert.empty(self.reference_models) + else: + Assert.eq(self.reference_models.keys(), {name}) + if self.model.base_model.use_absolute_position_embeddings: + Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + if self.model.base_model.distillation_model is not None: + # TODO: Support loss masking for distillation? + assert not self.batch.use_loss_masking_spans + for reference_model in self.reference_models.values(): + Assert.none(reference_model.model.base_model.distillation_model) + # TODO: Support more LM head features. + Assert.none(reference_model.model.base_model.cross_entropy_splits) + Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) + Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) + + @classmethod + def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: + from fast_llm.models.gpt.model import GPTInferenceRunner + + return GPTInferenceRunner From a1f44d41d47c4424b2b033ae5515d7a067d110c7 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 21:31:43 +0000 Subject: [PATCH 037/122] conversion apriel ssm --- fast_llm/models/gpt/config.py | 2 ++ fast_llm/models/ssm/conversion.py | 24 ++++++++++++------------ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 988f27b89..b82dd3e85 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -226,4 +226,6 @@ def get_trainer_class(cls) -> type["GPTTrainer"]: def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: from fast_llm.models.gpt.model import GPTInferenceRunner + # TODO" we dont have inference runner for SSM/Hybrid yet, should return None? + return GPTInferenceRunner diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index a8b6ceff3..fb7776d51 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -349,58 +349,58 @@ def _create_weight_converters(self) -> list[WeightConverter]: # Embedding and output if self._model.config.base_model.tie_word_embeddings: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) else: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) # Final norm converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "backbone.final_layernorm", norm_bias + f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias ) for i in range(num_layers): # SSM converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"backbone.layers.{i}.mixer.in_proj", ssm_bias + f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"backbone.layers.{i}.mixer.out_proj", ssm_bias + f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias ) converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"backbone.layers.{i}.mixer.D", self._model.config.base_model) + WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"backbone.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model ) ) converters.append( WeightConverter( f"layers.{i+1}.mixer.conv1d_weight", - f"backbone.layers.{i}.mixer.conv1d.weight", + f"model.layers.{i}.mixer.conv1d.weight", self._model.config.base_model, ) ) converters.append( WeightConverter( f"layers.{i+1}.mixer.conv1d_bias", - f"backbone.layers.{i}.mixer.conv1d.bias", + f"model.layers.{i}.mixer.conv1d.bias", self._model.config.base_model, ) ) # Norm converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_1", f"backbone.layers.{i}.input_layernorm", norm_bias + f"layers.{i+1}.norm_1", f"model.layers.{i}.input_layernorm", norm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_2", f"backbone.layers.{i}.post_attention_layernorm", norm_bias + f"layers.{i+1}.norm_2", f"model.layers.{i}.post_attention_layernorm", norm_bias ) # MLP - converters += self._get_mlp_converters(f"layers.{i+1}", f"backbone.layers.{i}") + converters += self._get_mlp_converters(f"layers.{i+1}", f"model.layers.{i}") return converters From fbec02df596e91d013ea814d4cd00b4c346d2849 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 21:32:21 +0000 Subject: [PATCH 038/122] config apriel --- fast_llm/models/ssm/external/configuration_ssm_apriel.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/ssm/external/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/configuration_ssm_apriel.py index 0c75ca658..2e5d5810c 100644 --- a/fast_llm/models/ssm/external/configuration_ssm_apriel.py +++ b/fast_llm/models/ssm/external/configuration_ssm_apriel.py @@ -56,6 +56,7 @@ def __init__( mlp_bias=False, rms_norm_eps=1e-5, ssm_cfg: dict = None, + head_dim: int = 128, **kwargs, ): self.vocab_size = vocab_size @@ -71,8 +72,7 @@ def __init__( self.use_cache = use_cache # self.rope_theta = rope_theta self.mlp_bias = mlp_bias - # self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads - + self.head_dim = head_dim # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, copy it it to 'rope_type'. # if self.rope_scaling is not None and "type" in self.rope_scaling: @@ -94,8 +94,9 @@ def __init__( "chunk_size": 128, "activation": "identity", "bias": False, - "d_inner": 4104, # to make sure we have 24 heads + "d_inner": 24 * self.head_dim, # num_heads * head_dim } + assert self.head_dim == self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"] __all__ = ["AprielConfig"] From 75d646059fcaac0762aaedae1c8a175ea2a67e37 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 23:53:19 +0000 Subject: [PATCH 039/122] temp checkpoint conversion --- fast_llm/models/ssm/conversion.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index fb7776d51..a9a139b55 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -1,3 +1,4 @@ +import dataclasses import json import os import pathlib @@ -31,6 +32,23 @@ pass +@dataclasses.dataclass(kw_only=True) +class HybridBlockLayoutConverter(ParamConverter): + num_layers_getter: typing.Callable[[typing.Any], int] = lambda config: config.transformer.num_layers + + # TODO: generalize this t + def __post_init__(self) -> None: + Assert.eq(len(self.fast_llm_names), 1) + Assert.eq(len(self.export_names), 1) + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + # Use the expanded list as-is + return (["m2"],) + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + return (["m2"],) + + class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig @@ -170,6 +188,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("tie_word_embeddings",),), export_names=(("tie_embeddings",),), ), + HybridBlockLayoutConverter( + fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),) + ), ] def _create_weight_converters(self) -> list[WeightConverter]: @@ -338,7 +359,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("tie_word_embeddings",),), export_names=(("tie_word_embeddings",),), ), - ConstantImportParamConverter(fast_llm_names=(("hybrid_block_layout"),), fast_llm_value=["m2"]), + # ConstantImportParamConverter(fast_llm_names=(("hybrid_block_layout"),), fast_llm_value=["m2"]), + HybridBlockLayoutConverter( + fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),) + ), ] def _create_weight_converters(self) -> list[WeightConverter]: From 73a4252e2f93b5ae5cdb36a9a026bd81db440a01 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 30 Apr 2025 13:33:04 +0000 Subject: [PATCH 040/122] block pattern for hybrid conversion --- fast_llm/models/ssm/conversion.py | 56 +++++++++++++++++++------------ 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index a9a139b55..fb862a023 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -1,4 +1,3 @@ -import dataclasses import json import os import pathlib @@ -32,21 +31,39 @@ pass -@dataclasses.dataclass(kw_only=True) -class HybridBlockLayoutConverter(ParamConverter): - num_layers_getter: typing.Callable[[typing.Any], int] = lambda config: config.transformer.num_layers +class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + """ + This is a temporary solution for importing/exporting hybrid models. Since there is no standard solution for this in HF, we just use the block_pattern. + If block_pattern is None, it will multiply the provided default block type by the number of layers and export/import it. + If block_pattern is provided, it will export/import it as-is. + """ - # TODO: generalize this t - def __post_init__(self) -> None: - Assert.eq(len(self.fast_llm_names), 1) - Assert.eq(len(self.export_names), 1) + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + _default_block_type: str = "m2" - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - # Use the expanded list as-is - return (["m2"],) + @classmethod + def _import_config(cls, config, architecture_only: bool = False): + cls.num_layers = config["n_layer"] if "n_layer" in config else config["num_hidden_layers"] + cls.block_pattern = config.get("hybrid_block_layout", None) + return super()._import_config(config, architecture_only) - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (["m2"],) + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + if cls.block_pattern is not None: + block_converter = MappedConfigParamConverter( + fast_llm_names=(("hybrid_block_layout",),), + export_names=(("hybrid_block_layout",),), + fast_llm_value=cls.block_pattern, + export_value=cls.block_pattern, + ) + else: + block_converter = ConstantImportParamConverter( + fast_llm_names=(("hybrid_block_layout",),), + fast_llm_value=[cls._default_block_type] * cls.num_layers, + ) + + return super()._create_config_converters() + [block_converter] class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @@ -128,10 +145,11 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] -class LLambaHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): +class LLambaHuggingfaceCheckpointHandler(HybridModelCheckpointHandler, CommonSSMHuggingfaceCheckpointHandler): _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat + _default_block_type: str = "m2" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -188,9 +206,6 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("tie_word_embeddings",),), export_names=(("tie_embeddings",),), ), - HybridBlockLayoutConverter( - fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),) - ), ] def _create_weight_converters(self) -> list[WeightConverter]: @@ -316,9 +331,10 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An json.dump(config, f) -class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): +class AprielSSMHuggingfaceCheckpointHandler(HybridModelCheckpointHandler, CommonSSMHuggingfaceCheckpointHandler): _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat + _default_block_type: str = "m2" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -359,10 +375,6 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("tie_word_embeddings",),), export_names=(("tie_word_embeddings",),), ), - # ConstantImportParamConverter(fast_llm_names=(("hybrid_block_layout"),), fast_llm_value=["m2"]), - HybridBlockLayoutConverter( - fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),) - ), ] def _create_weight_converters(self) -> list[WeightConverter]: From 5afc7dc21ead590a67c7bc4fd2765f4f95456205 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 30 Apr 2025 14:18:55 +0000 Subject: [PATCH 041/122] SSMBlockType --- fast_llm/layers/ssm/config.py | 13 +++++++++++++ fast_llm/models/ssm/config.py | 19 +++++++++++-------- fast_llm/models/ssm/conversion.py | 7 ++++--- fast_llm/models/ssm/model.py | 10 +++++----- 4 files changed, 33 insertions(+), 16 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 2effa8a6e..459401df1 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,3 +1,5 @@ +import enum + from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.functional.config import ActivationType @@ -20,6 +22,17 @@ class SSMDimNames: v_heads = "v_heads" # Number of V heads +class SSMBlockType(str, enum.Enum): + """ + An enum for the available mamba types for the MLP layer. + """ + + mamba = "m" + mamba2_discrete = "m2d" + mamba2 = "m2" + transformer = "t" + + @config_class() class SSMArchitectureConfig(BaseModelArchitectureConfig): _abstract = False diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 7cad0d529..d77d206b0 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -9,7 +9,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.language_model.config import LanguageModelArchitectureConfig, LanguageModelBaseConfig -from fast_llm.layers.ssm.config import SSMDimNames +from fast_llm.layers.ssm.config import SSMBlockType, SSMDimNames from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert @@ -24,8 +24,8 @@ class HybridSSMArchitectureConfig(LanguageModelArchitectureConfig): _abstract = False hybrid_block_layout: list[str] = Field( - default_factory=lambda: ["m2"], - desc="Pattern of blocks to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for Descrete Mamba2.", + default_factory=lambda: [SSMBlockType.mamba2_discrete.value], + desc=f"Pattern of blocks to use in the model. Availabel types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, ) @@ -44,9 +44,12 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: Some of these can be setup directly in the layer config, but keeping them here for clarity. """ super().setup_tensor_space(tensor_space) - if not "m2" in self.hybrid_block_layout and not "m" in self.hybrid_block_layout: + if ( + not SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout + and not SSMBlockType.mamba.value in self.hybrid_block_layout + ): raise ValueError( - "Block pattern must contain at least one 'm' or 'm2', use gpt model for transformer only architectures" + f"Block pattern must contain at least one '{SSMBlockType.mamba2_discrete.value}' or '{SSMBlockType.mamba.value}', use gpt model for transformer only architectures" ) if self.ssm.dt_rank is None: @@ -69,7 +72,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - if "m2" in self.hybrid_block_layout: + if SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout: # Mamba2 specific dimensions # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 headdim = d_inner // self.ssm.n_v_heads @@ -101,8 +104,8 @@ def _validate(self): Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) Assert.custom( - lambda _: all(block_type in ["t", "m", "m2"] for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be 't' or 'm' or 'm2'", + lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), + f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", ) super()._validate() diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index fb862a023..c2e54ca09 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -18,6 +18,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationType +from fast_llm.layers.ssm.config import SSMBlockType from fast_llm.models.gpt.conversion import MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHuggingfaceCheckpointFormat, @@ -40,7 +41,7 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - _default_block_type: str = "m2" + _default_block_type: str = SSMBlockType.mamba2_discrete.value @classmethod def _import_config(cls, config, architecture_only: bool = False): @@ -149,7 +150,7 @@ class LLambaHuggingfaceCheckpointHandler(HybridModelCheckpointHandler, CommonSSM _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat - _default_block_type: str = "m2" + _default_block_type: str = SSMBlockType.mamba2_discrete.value @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -334,7 +335,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class AprielSSMHuggingfaceCheckpointHandler(HybridModelCheckpointHandler, CommonSSMHuggingfaceCheckpointHandler): _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat - _default_block_type: str = "m2" + _default_block_type: str = SSMBlockType.mamba2_discrete.value @classmethod def _create_config_converters(cls) -> list[ParamConverter]: diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 33d2c185c..c9d1ba7d8 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -11,7 +11,7 @@ from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.models.gpt.model import GPTBaseModel -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig +from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType logger = logging.getLogger(__name__) @@ -43,7 +43,7 @@ def get_layers(self) -> list[Layer]: # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): - if block_type == "t": + if block_type == SSMBlockType.transformer.value: # Transformer block layers.append( TransformerLayer( @@ -52,7 +52,7 @@ def get_layers(self) -> list[Layer]: layer_index=i + 1, ) ) - elif block_type == "m2": + elif block_type == SSMBlockType.mamba2_discrete.value: mamba_block = self.SSM_BLOCK_CLS( config_transformer=self._config.transformer, config_ssm=self._config.ssm, @@ -62,7 +62,7 @@ def get_layers(self) -> list[Layer]: ) layers.append(mamba_block) - elif block_type == "m": + elif block_type == SSMBlockType.mamba.value: # Create Mamba block mamba_block = self.SSM_BLOCK_CLS( config_transformer=self._config.transformer, @@ -74,7 +74,7 @@ def get_layers(self) -> list[Layer]: layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be 't' or 'm' or 'm2'") + raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") # Add the language model head layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)) From 8e9facf9584ff5e8a72cecdcbc8c0b0d6fcdaab8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 30 Apr 2025 19:02:48 +0000 Subject: [PATCH 042/122] wip --- fast_llm/layers/ssm/mamba2.py | 354 +++ .../models/ssm/external/ariel_to_ssm.ipynb | 2240 ++++++++++++++++- 2 files changed, 2507 insertions(+), 87 deletions(-) create mode 100644 fast_llm/layers/ssm/mamba2.py diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py new file mode 100644 index 000000000..5763cb92c --- /dev/null +++ b/fast_llm/layers/ssm/mamba2.py @@ -0,0 +1,354 @@ +""" +This code is adapted from https://github.com/jxiw/MambaInLlama/blob/main/mamba2/hybrid_mamba_layer.py +""" + +import math + +import causal_conv1d +import einops +import mamba_ssm.ops.triton.ssd_combined +import torch +from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated + +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.common.linear import Linear +from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.tensor import kaiming_init_ + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Mamba2(torch.nn.Module): + def __init__( + self, + config: SSMConfig, + layer_idx: int, + tensor_space: TensorSpace, + ): + # factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.config: SSMConfig = config + bias = config.add_bias_linear + self.layer_idx = layer_idx + + td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) + tensor_space.get_tensor_dim(SSMDimNames.state_dim) + td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) + tensor_space.get_tensor_dim(SSMDimNames.conv_dim) + tensor_space.get_tensor_dim(SSMDimNames.qk_heads) + tensor_space.get_tensor_dim(SSMDimNames.v_heads) + tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) + tensor_space.get_tensor_dim(SSMDimNames.inner_proj_mamba2) + + # self.d_model = d_model + # self.d_state = d_state + # self.d_conv = d_conv + # self.conv_init = conv_init + # self.expand = expand + # self.process_group = process_group + # self.sequence_parallel = sequence_parallel + # self.world_size = 1 if process_group is None else process_group.size() + # self.local_rank = 0 if process_group is None else process_group.rank() + # self.d_inner = d_inner if d_inner is not None else (self.expand * self.d_model) // self.world_size + # # assert self.d_inner * self.world_size == self.expand * self.d_model + # self.headdim = headdim + # self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size + # assert ngroups % self.world_size == 0 + # self.ngroups = ngroups // self.world_size + # assert self.d_ssm % self.headdim == 0 + # self.nheads = self.d_ssm // self.headdim + # self.D_has_hdim = D_has_hdim + # self.rmsnorm = rmsnorm + # self.norm_before_gate = norm_before_gate + # self.dt_limit = dt_limit + # self.activation = "silu" + # self.chunk_size = chunk_size + # self.use_mem_eff_path = use_mem_eff_path + # self.layer_idx = layer_idx + # self.d_xb = d_xb + # self.repeat_group = self.d_inner // self.d_xb + # self.repeat_kv_before_conv = repeat_kv_before_conv + + assert self.d_inner == self.ngroups * self.d_state + assert self.d_inner == self.d_ssm + + self.nheads = self.ngroups + self.headdim = self.d_state + + # Order: [z, x, B, C, dt] + # [hidden_dim, hidden_dim, d_state] + d_in_proj = self.d_inner + self.d_xb + self.d_xb + self.d_inner + self.nheads + # d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + if self.process_group is None: + self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs) + else: + self.in_proj = ColumnParallelLinear( + self.d_model, + d_in_proj * self.world_size, + bias=bias, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + **factory_kwargs, + ) + + # conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state + + if self.repeat_kv_before_conv: + conv_dim = self.d_inner + self.d_inner + self.d_inner + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + else: + conv_dim = self.d_inner + self.d_xb + self.d_xb + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + if self.conv_init is not None: + nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) + + self.act = nn.SiLU() + + # Initialize log dt bias + dt = torch.exp( + torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Just to be explicit. Without this we already don't put wd on dt_bias because of the check + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] + A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range) + A_log = torch.log(A).to(dtype=dtype) + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)) + self.D._no_weight_decay = True + + if self.rmsnorm: + assert RMSNormGated is not None + self.norm = RMSNormGated( + self.d_ssm, + eps=1e-5, + norm_before_gate=self.norm_before_gate, + group_size=self.d_ssm // ngroups, + **factory_kwargs, + ) + + # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + self.out_proj = Linear( + td_inner, + td_model, + bias=bias, + weight_init_method=kaiming_init_(td_inner.size), + ) + + def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None): + """ + u: (batch, seqlen, hidden_dim) if seqlen=None. + If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we + split u during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + Returns: same shape as u + """ + seqlen_og = seqlen + if seqlen is None: + batch, seqlen, dim = u.shape + else: + batch_seqlen, dim = u.shape + batch = batch_seqlen // seqlen + + conv_state, ssm_state = None, None + if inference_params is not None: + inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch + conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch) + if inference_params.seqlen_offset > 0: + # The states are updated inplace + out, _, _ = self.step(u, conv_state, ssm_state) + return out + + zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj) + if seqlen_og is not None: + zxbcdt = einops.rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen) + # If the model is loaded in fp16, without the .float() here, A might be -inf + A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state) + dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit) + + # [z, x, B, C, dt] + d_mlp = (zxbcdt.shape[-1] - 2 * self.d_inner - 2 * self.d_xb - self.nheads) // 2 + z0, x0, z, xBC, dt = torch.split( + zxbcdt, [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.d_xb, self.nheads], dim=-1 + ) + + if self.repeat_kv_before_conv: + x, B, C = torch.split(xBC, [self.d_xb, self.d_xb, self.ngroups * self.d_state], dim=-1) + # minic the GQA + x = einops.rearrange(x, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) + x = repeat_kv(x, self.repeat_group) + # x shape: (bsz, n_group, l, dim) + B = einops.rearrange(B, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) + B = repeat_kv(B, self.repeat_group) + # combine x, B, C + x = einops.rearrange(x, "b g l p -> b l (g p)") + B = einops.rearrange(B, "b g l p -> b l (g p)") + xBC = torch.cat((x, B, C), dim=-1) + + if conv_state is not None: + if cu_seqlens is None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = einops.rearrange(xBC, "b l d -> b d l") + conv_state.copy_( + torch.nn.functional.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0)) + ) # Update state (B D W) + else: + assert ( + causal_conv1d.causal_conv1d_varlen_states is not None + ), "varlen inference requires causal_conv1d package" + assert batch == 1, "varlen inference only supports batch dimension 1" + conv_varlen_states = causal_conv1d.causal_conv1d_varlen_states( + xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1] + ) + conv_state.copy_(conv_varlen_states) + assert self.activation in ["silu", "swish"] + + if causal_conv1d.causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + assert seq_idx is None, "varlen conv1d requires the causal_conv1d package" + xBC = self.act( + self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.d_conv - 1) :] + ) # (B, L, self.d_ssm + 2 * ngroups * d_state) + else: + xBC = causal_conv1d.causal_conv1d_fn( + xBC.transpose(1, 2), + einops.rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=seq_idx, + ).transpose(1, 2) + + if self.repeat_kv_before_conv: + x, B, C = torch.split( + xBC, [self.ngroups * self.d_state, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1 + ) + + y = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( + einops.rearrange(x, "b l (h p) -> b l h p", p=self.headdim), + dt, + A, + einops.rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), + einops.rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), + chunk_size=self.chunk_size, + D=einops.rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, + z=einops.rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None, + dt_bias=self.dt_bias, + dt_softplus=True, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + **dt_limit_kwargs, + return_final_states=ssm_state is not None, + return_varlen_states=cu_seqlens is not None and inference_params is not None, + ) + + else: + # self.d_xb + self.d_xb + self.d_inner + x, B, C = torch.split(xBC, [self.d_xb, self.d_xb, self.ngroups * self.d_state], dim=-1) + + # minic the GQA + x = einops.rearrange(x, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) + x = repeat_kv(x, self.repeat_group) + # x shape: (bsz, n_group, l, dim) + + B = einops.rearrange(B, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) + B = repeat_kv(B, self.repeat_group) + + y = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( + # einops.rearrange(x, "b l (h p) -> b l h p", p=self.headdim), + einops.rearrange(x, "b g l p -> b l g p"), + dt, + A, + # einops.rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), + einops.rearrange(B, "b g l n -> b l g n"), + einops.rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), + chunk_size=self.chunk_size, + D=einops.rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, + z=einops.rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None, + dt_bias=self.dt_bias, + dt_softplus=True, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + **dt_limit_kwargs, + return_final_states=ssm_state is not None, + return_varlen_states=cu_seqlens is not None and inference_params is not None, + ) + + if ssm_state is not None: + y, last_state, *rest = y + if cu_seqlens is None: + ssm_state.copy_(last_state) + else: + varlen_states = rest[0] + ssm_state.copy_(varlen_states) + y = einops.rearrange(y, "b l h p -> b l (h p)") + if self.rmsnorm: + y = self.norm(y, z) + if d_mlp > 0: + y = torch.cat([torch.nn.functional.silu(z0) * x0, y], dim=-1) + if seqlen_og is not None: + y = einops.rearrange(y, "b l d -> (b l) d") + out = self.out_proj(y) + return out + + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + (batch_size,) + conv_state = torch.zeros( + batch_size, + self.d_conv, + self.conv1d.weight.shape[0], + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ).transpose(1, 2) + ssm_state = torch.zeros( + batch_size, + self.nheads, + self.headdim, + self.d_state, + device=self.in_proj.weight.device, + dtype=self.in_proj.weight.dtype, + ) + inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + # TODO: What if batch size changes between generation, and we reuse the same states? + if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb index 85608075a..a8390fa3d 100644 --- a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb +++ b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -41,14 +41,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 6.68it/s]\n" + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 8.90it/s]\n" ] }, { @@ -82,7 +82,7 @@ ")" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -97,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -106,7 +106,7 @@ "torch.bfloat16" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -117,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -126,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -135,7 +135,7 @@ "4.83207168" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -146,7 +146,58 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n" + ] + } + ], + "source": [ + "config_apriel = AprielSSMConfig.from_pretrained(\"/mnt/checkpoints_fml/pretrained_models/ssm/apriel_ssm_instruct_base\", trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", + "You are using a model of type llamba to instantiate a model of type apriel_ssm. This is not supported for all configurations of models and can yield errors.\n" + ] + }, + { + "ename": "KeyError", + "evalue": "'n_qk_heads'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m stage2_checkpoint \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m stage2_apriel_ssm \u001b[38;5;241m=\u001b[39m \u001b[43mAprielSSMForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstage2_checkpoint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbfloat16\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3571\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3569\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m 3570\u001b[0m config_path \u001b[38;5;241m=\u001b[39m config \u001b[38;5;28;01mif\u001b[39;00m config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> 3571\u001b[0m config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3572\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3573\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3574\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_unused_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 3575\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3576\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3577\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3578\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3579\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3580\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3581\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3582\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_auto\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_auto_class\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3583\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_pipeline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_pipeline\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3584\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3585\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3587\u001b[0m \u001b[38;5;66;03m# In case one passes a config to `from_pretrained` + \"attn_implementation\"\u001b[39;00m\n\u001b[1;32m 3588\u001b[0m \u001b[38;5;66;03m# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3592\u001b[0m \u001b[38;5;66;03m# we pop attn_implementation from the kwargs but this handles the case where users\u001b[39;00m\n\u001b[1;32m 3593\u001b[0m \u001b[38;5;66;03m# passes manually the config to `from_pretrained`.\u001b[39;00m\n\u001b[1;32m 3594\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:569\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type:\n\u001b[1;32m 564\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 565\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are using a model of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to instantiate a model of type \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 566\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 567\u001b[0m )\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:740\u001b[0m, in \u001b[0;36mPretrainedConfig.from_dict\u001b[0;34m(cls, config_dict, **kwargs)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;66;03m# We remove it from kwargs so that it does not appear in `return_unused_kwargs`.\u001b[39;00m\n\u001b[1;32m 738\u001b[0m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 740\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(config, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpruned_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 743\u001b[0m config\u001b[38;5;241m.\u001b[39mpruned_heads \u001b[38;5;241m=\u001b[39m {\u001b[38;5;28mint\u001b[39m(key): value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m config\u001b[38;5;241m.\u001b[39mpruned_heads\u001b[38;5;241m.\u001b[39mitems()}\n", + "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/configuration_ssm_apriel.py:99\u001b[0m, in \u001b[0;36mAprielSSMConfig.__init__\u001b[0;34m(self, vocab_size, hidden_size, intermediate_size, num_hidden_layers, hidden_act, initializer_range, use_cache, pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings, mlp_bias, rms_norm_eps, ssm_cfg, head_dim, **kwargs)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 82\u001b[0m pad_token_id\u001b[38;5;241m=\u001b[39mpad_token_id,\n\u001b[1;32m 83\u001b[0m bos_token_id\u001b[38;5;241m=\u001b[39mbos_token_id,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 87\u001b[0m )\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg \u001b[38;5;241m=\u001b[39m ssm_cfg \u001b[38;5;129;01mor\u001b[39;00m {\n\u001b[1;32m 90\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_state\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 91\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_v_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim, \u001b[38;5;66;03m# num_heads * head_dim\u001b[39;00m\n\u001b[1;32m 98\u001b[0m }\n\u001b[0;32m---> 99\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mssm_cfg\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mn_qk_heads\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n", + "\u001b[0;31mKeyError\u001b[0m: 'n_qk_heads'" + ] + } + ], + "source": [ + "stage2_checkpoint = \"/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\"\n", + "stage2_apriel_ssm = AprielSSMForCausalLM.from_pretrained(stage2_checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -162,12 +213,13 @@ " pad_token_id=config.pad_token_id,\n", " bos_token_id=config.bos_token_id,\n", " eos_token_id=config.eos_token_id,\n", + " head_dim=config.head_dim,\n", " rms_norm_eps=config.rms_norm_eps)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -176,60 +228,1984 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AprielSSMConfig {\n", - " \"_attn_implementation_autoset\": true,\n", - " \"bos_token_id\": 1,\n", - " \"eos_token_id\": 2,\n", - " \"hidden_act\": \"silu\",\n", - " \"hidden_size\": 4096,\n", - " \"initializer_range\": 0.02,\n", - " \"intermediate_size\": 8192,\n", - " \"mlp_bias\": false,\n", - " \"model_type\": \"apriel_ssm\",\n", - " \"num_hidden_layers\": 28,\n", - " \"rms_norm_eps\": 1e-05,\n", - " \"ssm_cfg\": {\n", - " \"activation\": \"identity\",\n", - " \"bias\": false,\n", - " \"chunk_size\": 128,\n", - " \"d_inner\": 4104,\n", - " \"d_state\": 64,\n", - " \"expand\": 1,\n", - " \"n_qk_heads\": 24,\n", - " \"n_v_heads\": 24\n", - " },\n", - " \"tie_word_embeddings\": false,\n", - " \"transformers_version\": \"4.48.1\",\n", - " \"use_cache\": true,\n", - " \"vocab_size\": 131072\n", - "}" + "OrderedDict([('model.embed_tokens.weight',\n", + " tensor([[ 0.0105, 0.0330, -0.0032, ..., 0.0076, -0.0051, 0.0112],\n", + " [-0.0111, -0.0101, 0.0064, ..., 0.0144, 0.0098, -0.0194],\n", + " [ 0.0301, 0.0228, 0.0105, ..., -0.0159, 0.0112, -0.0009],\n", + " ...,\n", + " [ 0.0266, 0.0224, -0.0150, ..., 0.0189, -0.0253, -0.0300],\n", + " [-0.0304, 0.0249, 0.0140, ..., -0.0235, 0.0315, -0.0188],\n", + " [-0.0215, -0.0034, 0.0035, ..., -0.0125, 0.0084, 0.0246]])),\n", + " ('model.layers.0.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.0.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.0.mixer.in_proj.weight',\n", + " tensor([[ 0.0104, 0.0055, -0.0148, ..., 0.0208, -0.0074, 0.0015],\n", + " [ 0.0102, 0.0148, 0.0148, ..., -0.0041, 0.0224, -0.0336],\n", + " [ 0.0129, -0.0179, -0.0120, ..., 0.0175, 0.0300, -0.0234],\n", + " ...,\n", + " [-0.0215, 0.0002, 0.0093, ..., -0.0424, 0.0016, -0.0162],\n", + " [-0.0178, -0.0093, 0.0226, ..., 0.0005, 0.0062, 0.0150],\n", + " [-0.0204, 0.0039, -0.0364, ..., -0.0128, 0.0002, 0.0134]])),\n", + " ('model.layers.0.mixer.conv1d.weight',\n", + " tensor([[[-0.1064, -0.3782, -0.3080, -0.3179]],\n", + " \n", + " [[-0.3493, 0.2230, 0.1062, 0.0614]],\n", + " \n", + " [[-0.4650, 0.0300, 0.3021, 0.1197]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.3686, 0.0679, 0.1440, 0.4445]],\n", + " \n", + " [[-0.1480, 0.3750, -0.0552, -0.0297]],\n", + " \n", + " [[ 0.0677, 0.0925, -0.0268, -0.0232]]])),\n", + " ('model.layers.0.mixer.conv1d.bias',\n", + " tensor([ 0.1379, 0.0862, -0.0723, ..., -0.2628, -0.1867, -0.1233])),\n", + " ('model.layers.0.mixer.out_proj.weight',\n", + " tensor([[ 0.0208, -0.0106, -0.0016, ..., 0.0117, 0.0140, -0.0040],\n", + " [-0.0147, 0.0419, 0.0327, ..., -0.0073, -0.0127, 0.0190],\n", + " [-0.0218, 0.0030, 0.0115, ..., -0.0062, 0.0214, 0.0105],\n", + " ...,\n", + " [ 0.0089, 0.0154, -0.0178, ..., -0.0206, -0.0378, 0.0102],\n", + " [ 0.0153, -0.0249, 0.0219, ..., 0.0119, 0.0019, 0.0383],\n", + " [-0.0126, 0.0284, -0.0035, ..., 0.0118, -0.0186, -0.0232]])),\n", + " ('model.layers.0.mlp.gate_proj.weight',\n", + " tensor([[-0.0032, -0.0405, 0.0180, ..., -0.0030, -0.0222, 0.0069],\n", + " [-0.0071, -0.0064, -0.0207, ..., 0.0037, -0.0077, 0.0261],\n", + " [ 0.0236, 0.0167, 0.0065, ..., 0.0064, 0.0035, -0.0092],\n", + " ...,\n", + " [-0.0357, 0.0192, 0.0099, ..., -0.0067, -0.0181, 0.0082],\n", + " [-0.0139, -0.0161, -0.0015, ..., -0.0052, -0.0337, 0.0514],\n", + " [ 0.0105, -0.0205, 0.0198, ..., 0.0090, 0.0315, 0.0066]])),\n", + " ('model.layers.0.mlp.up_proj.weight',\n", + " tensor([[ 0.0074, 0.0237, -0.0300, ..., 0.0343, 0.0016, 0.0395],\n", + " [ 0.0270, 0.0085, 0.0193, ..., 0.0199, -0.0139, 0.0094],\n", + " [ 0.0036, 0.0073, 0.0149, ..., 0.0094, 0.0346, -0.0111],\n", + " ...,\n", + " [ 0.0159, -0.0346, -0.0128, ..., 0.0377, -0.0531, -0.0305],\n", + " [ 0.0283, 0.0162, -0.0377, ..., -0.0254, 0.0110, -0.0167],\n", + " [-0.0277, 0.0130, 0.0161, ..., 0.0089, -0.0190, 0.0214]])),\n", + " ('model.layers.0.mlp.down_proj.weight',\n", + " tensor([[ 0.0157, 0.0105, 0.0036, ..., 0.0229, 0.0080, 0.0303],\n", + " [-0.0143, -0.0067, 0.0016, ..., 0.0494, -0.0043, 0.0072],\n", + " [-0.0148, 0.0113, 0.0025, ..., -0.0186, 0.0206, -0.0119],\n", + " ...,\n", + " [-0.0226, 0.0099, 0.0010, ..., 0.0123, -0.0170, 0.0024],\n", + " [-0.0120, -0.0015, -0.0355, ..., 0.0064, 0.0175, -0.0065],\n", + " [ 0.0364, 0.0364, 0.0265, ..., -0.0222, 0.0030, 0.0296]])),\n", + " ('model.layers.0.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.0.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.1.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.1.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.1.mixer.in_proj.weight',\n", + " tensor([[-0.0116, -0.0182, -0.0017, ..., -0.0216, -0.0136, -0.0203],\n", + " [-0.0142, -0.0106, -0.0334, ..., 0.0287, -0.0273, 0.0050],\n", + " [ 0.0131, -0.0106, -0.0012, ..., 0.0261, -0.0228, -0.0026],\n", + " ...,\n", + " [-0.0029, 0.0023, 0.0360, ..., -0.0195, 0.0018, -0.0227],\n", + " [ 0.0004, 0.0015, -0.0051, ..., -0.0095, 0.0269, 0.0179],\n", + " [ 0.0295, -0.0520, 0.0009, ..., 0.0019, 0.0255, 0.0478]])),\n", + " ('model.layers.1.mixer.conv1d.weight',\n", + " tensor([[[-0.4725, -0.2938, -0.3816, -0.1239]],\n", + " \n", + " [[-0.2002, 0.3790, 0.1908, -0.4679]],\n", + " \n", + " [[-0.3674, 0.3774, -0.2479, 0.4324]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.4181, 0.2263, -0.1937, 0.3585]],\n", + " \n", + " [[ 0.0704, 0.0913, 0.4217, 0.3004]],\n", + " \n", + " [[ 0.3175, -0.3239, -0.0614, -0.3978]]])),\n", + " ('model.layers.1.mixer.conv1d.bias',\n", + " tensor([ 0.4302, 0.0269, -0.3462, ..., 0.4887, 0.2848, 0.0745])),\n", + " ('model.layers.1.mixer.out_proj.weight',\n", + " tensor([[-0.0069, 0.0233, 0.0133, ..., -0.0064, -0.0085, 0.0166],\n", + " [-0.0302, 0.0129, -0.0042, ..., 0.0109, 0.0009, -0.0087],\n", + " [-0.0373, -0.0233, -0.0043, ..., -0.0017, 0.0384, -0.0114],\n", + " ...,\n", + " [-0.0219, 0.0330, -0.0341, ..., 0.0080, 0.0089, 0.0268],\n", + " [-0.0019, -0.0069, 0.0276, ..., 0.0182, -0.0240, 0.0163],\n", + " [ 0.0081, 0.0070, 0.0156, ..., -0.0135, 0.0469, -0.0221]])),\n", + " ('model.layers.1.mlp.gate_proj.weight',\n", + " tensor([[ 0.0175, -0.0074, -0.0028, ..., 0.0197, 0.0034, 0.0221],\n", + " [ 0.0063, 0.0339, -0.0047, ..., 0.0037, -0.0126, -0.0342],\n", + " [-0.0093, -0.0148, -0.0236, ..., 0.0190, -0.0451, -0.0173],\n", + " ...,\n", + " [ 0.0167, 0.0161, 0.0019, ..., -0.0083, -0.0133, 0.0141],\n", + " [-0.0163, 0.0383, -0.0203, ..., 0.0336, -0.0148, 0.0013],\n", + " [-0.0138, -0.0275, -0.0268, ..., -0.0243, -0.0031, -0.0227]])),\n", + " ('model.layers.1.mlp.up_proj.weight',\n", + " tensor([[ 0.0054, 0.0031, 0.0256, ..., 0.0002, 0.0020, -0.0050],\n", + " [ 0.0247, -0.0298, -0.0218, ..., -0.0161, 0.0253, 0.0128],\n", + " [-0.0231, -0.0012, 0.0130, ..., 0.0031, -0.0324, 0.0107],\n", + " ...,\n", + " [ 0.0359, -0.0202, 0.0386, ..., -0.0104, 0.0274, 0.0161],\n", + " [ 0.0062, -0.0111, 0.0338, ..., 0.0041, 0.0001, -0.0019],\n", + " [ 0.0105, -0.0258, 0.0184, ..., -0.0270, -0.0138, -0.0367]])),\n", + " ('model.layers.1.mlp.down_proj.weight',\n", + " tensor([[-0.0163, -0.0308, -0.0203, ..., 0.0002, -0.0227, 0.0019],\n", + " [ 0.0206, 0.0037, 0.0064, ..., -0.0261, -0.0206, 0.0063],\n", + " [ 0.0044, -0.0073, -0.0576, ..., -0.0015, -0.0082, 0.0022],\n", + " ...,\n", + " [-0.0034, 0.0142, -0.0547, ..., -0.0106, -0.0090, 0.0249],\n", + " [-0.0068, 0.0127, -0.0066, ..., -0.0255, 0.0004, 0.0106],\n", + " [-0.0293, 0.0146, -0.0142, ..., -0.0073, -0.0284, -0.0069]])),\n", + " ('model.layers.1.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.1.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.2.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.2.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.2.mixer.in_proj.weight',\n", + " tensor([[ 0.0337, -0.0055, -0.0538, ..., -0.0051, 0.0107, -0.0338],\n", + " [ 0.0227, -0.0008, 0.0003, ..., -0.0312, 0.0090, -0.0126],\n", + " [-0.0238, 0.0146, 0.0240, ..., -0.0114, -0.0180, 0.0025],\n", + " ...,\n", + " [-0.0208, -0.0261, 0.0227, ..., 0.0071, 0.0014, 0.0237],\n", + " [ 0.0356, 0.0372, 0.0186, ..., 0.0052, 0.0049, -0.0195],\n", + " [ 0.0023, -0.0159, -0.0238, ..., 0.0194, -0.0056, -0.0275]])),\n", + " ('model.layers.2.mixer.conv1d.weight',\n", + " tensor([[[ 0.1054, -0.4185, 0.4229, 0.3289]],\n", + " \n", + " [[-0.0081, 0.0321, 0.1334, -0.1055]],\n", + " \n", + " [[ 0.1587, -0.3806, -0.1336, -0.2662]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.2830, -0.3875, -0.2972, 0.0030]],\n", + " \n", + " [[ 0.4210, 0.2190, -0.4942, 0.0465]],\n", + " \n", + " [[-0.1830, -0.3686, 0.2928, -0.0313]]])),\n", + " ('model.layers.2.mixer.conv1d.bias',\n", + " tensor([-0.2931, -0.3513, -0.3013, ..., -0.1934, -0.3115, 0.3889])),\n", + " ('model.layers.2.mixer.out_proj.weight',\n", + " tensor([[-0.0038, -0.0160, -0.0042, ..., 0.0062, 0.0059, -0.0126],\n", + " [-0.0027, -0.0012, -0.0065, ..., -0.0032, 0.0129, -0.0298],\n", + " [ 0.0394, -0.0096, 0.0107, ..., -0.0290, 0.0248, 0.0308],\n", + " ...,\n", + " [ 0.0087, 0.0067, -0.0261, ..., -0.0038, -0.0168, 0.0485],\n", + " [ 0.0118, 0.0042, -0.0186, ..., 0.0104, 0.0281, 0.0028],\n", + " [ 0.0304, -0.0382, -0.0028, ..., -0.0264, -0.0050, 0.0050]])),\n", + " ('model.layers.2.mlp.gate_proj.weight',\n", + " tensor([[-0.0169, 0.0036, 0.0024, ..., 0.0429, 0.0313, 0.0167],\n", + " [-0.0100, 0.0011, -0.0024, ..., -0.0065, 0.0090, 0.0123],\n", + " [ 0.0102, 0.0282, 0.0166, ..., -0.0082, 0.0123, 0.0253],\n", + " ...,\n", + " [ 0.0168, -0.0056, -0.0096, ..., -0.0090, 0.0150, 0.0209],\n", + " [ 0.0258, 0.0113, -0.0093, ..., 0.0335, 0.0386, -0.0156],\n", + " [ 0.0129, 0.0338, -0.0006, ..., -0.0346, 0.0135, -0.0213]])),\n", + " ('model.layers.2.mlp.up_proj.weight',\n", + " tensor([[-0.0029, 0.0416, -0.0102, ..., -0.0413, 0.0019, 0.0063],\n", + " [ 0.0054, 0.0138, 0.0031, ..., -0.0077, -0.0070, -0.0016],\n", + " [ 0.0128, 0.0153, -0.0147, ..., -0.0131, -0.0244, 0.0097],\n", + " ...,\n", + " [-0.0190, -0.0025, 0.0322, ..., -0.0106, -0.0323, -0.0144],\n", + " [-0.0269, -0.0007, 0.0070, ..., 0.0191, -0.0025, 0.0033],\n", + " [-0.0311, 0.0217, -0.0021, ..., 0.0302, -0.0131, 0.0388]])),\n", + " ('model.layers.2.mlp.down_proj.weight',\n", + " tensor([[ 0.0150, -0.0127, 0.0372, ..., 0.0018, 0.0018, 0.0187],\n", + " [-0.0262, 0.0164, 0.0281, ..., 0.0120, -0.0187, -0.0177],\n", + " [ 0.0129, -0.0042, 0.0018, ..., -0.0136, 0.0278, 0.0284],\n", + " ...,\n", + " [ 0.0048, 0.0421, -0.0018, ..., 0.0002, -0.0064, 0.0085],\n", + " [ 0.0276, 0.0146, 0.0228, ..., 0.0055, -0.0288, -0.0081],\n", + " [-0.0133, 0.0102, 0.0318, ..., 0.0209, -0.0270, 0.0128]])),\n", + " ('model.layers.2.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.2.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.3.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.3.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.3.mixer.in_proj.weight',\n", + " tensor([[ 7.4766e-03, -9.8698e-03, -1.9172e-02, ..., 3.7842e-02,\n", + " -2.1648e-03, 2.8147e-03],\n", + " [ 2.4954e-02, -1.2659e-02, 8.0447e-04, ..., 3.1716e-02,\n", + " 4.9989e-03, 6.4200e-03],\n", + " [-3.3345e-02, -1.5256e-02, 2.7295e-02, ..., -1.1240e-02,\n", + " 9.7000e-03, 3.1136e-05],\n", + " ...,\n", + " [-2.0807e-04, -2.5132e-02, -1.9983e-02, ..., -2.9541e-02,\n", + " 4.6152e-04, 5.5341e-02],\n", + " [ 2.0498e-03, 2.2021e-02, -7.6882e-03, ..., 1.6469e-02,\n", + " -1.0645e-02, -1.8442e-03],\n", + " [ 2.0949e-03, -1.2398e-02, 1.2922e-02, ..., 1.1862e-02,\n", + " -4.7119e-03, 3.2352e-02]])),\n", + " ('model.layers.3.mixer.conv1d.weight',\n", + " tensor([[[ 0.2590, 0.1670, 0.3987, -0.1694]],\n", + " \n", + " [[-0.4425, 0.1468, 0.3060, -0.0764]],\n", + " \n", + " [[-0.3638, -0.0575, 0.2156, -0.2468]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0111, -0.0182, -0.3816, 0.0382]],\n", + " \n", + " [[-0.4723, -0.3712, 0.1963, 0.2877]],\n", + " \n", + " [[-0.4890, 0.1197, 0.1361, 0.3282]]])),\n", + " ('model.layers.3.mixer.conv1d.bias',\n", + " tensor([-0.4712, -0.3272, 0.4587, ..., -0.3145, 0.4086, 0.4005])),\n", + " ('model.layers.3.mixer.out_proj.weight',\n", + " tensor([[-0.0362, 0.0137, -0.0296, ..., -0.0028, 0.0104, 0.0393],\n", + " [ 0.0130, 0.0246, -0.0132, ..., 0.0082, -0.0044, -0.0054],\n", + " [-0.0081, -0.0115, -0.0064, ..., 0.0250, -0.0076, -0.0021],\n", + " ...,\n", + " [ 0.0230, -0.0055, 0.0056, ..., 0.0076, 0.0016, -0.0068],\n", + " [ 0.0472, -0.0068, 0.0336, ..., 0.0079, 0.0211, 0.0031],\n", + " [-0.0450, -0.0005, 0.0219, ..., 0.0044, -0.0006, -0.0278]])),\n", + " ('model.layers.3.mlp.gate_proj.weight',\n", + " tensor([[ 0.0034, 0.0445, -0.0132, ..., 0.0290, 0.0019, 0.0048],\n", + " [ 0.0271, 0.0109, 0.0028, ..., -0.0304, -0.0237, -0.0017],\n", + " [ 0.0098, 0.0252, 0.0392, ..., 0.0486, 0.0326, -0.0171],\n", + " ...,\n", + " [-0.0015, 0.0080, 0.0005, ..., -0.0158, -0.0067, 0.0347],\n", + " [-0.0638, 0.0120, 0.0076, ..., 0.0007, 0.0052, -0.0109],\n", + " [-0.0303, -0.0168, -0.0537, ..., -0.0163, -0.0030, -0.0068]])),\n", + " ('model.layers.3.mlp.up_proj.weight',\n", + " tensor([[-0.0074, -0.0101, 0.0073, ..., -0.0012, -0.0208, -0.0239],\n", + " [ 0.0035, 0.0010, 0.0157, ..., -0.0228, -0.0224, 0.0194],\n", + " [ 0.0457, -0.0129, -0.0063, ..., -0.0312, 0.0261, -0.0018],\n", + " ...,\n", + " [ 0.0012, 0.0093, 0.0121, ..., -0.0035, -0.0367, -0.0454],\n", + " [ 0.0308, -0.0334, 0.0062, ..., 0.0043, -0.0031, -0.0406],\n", + " [-0.0175, -0.0089, -0.0137, ..., -0.0322, -0.0070, -0.0219]])),\n", + " ('model.layers.3.mlp.down_proj.weight',\n", + " tensor([[ 0.0226, 0.0074, -0.0170, ..., 0.0035, 0.0420, -0.0085],\n", + " [ 0.0116, 0.0173, -0.0009, ..., -0.0302, 0.0075, 0.0153],\n", + " [-0.0092, 0.0119, 0.0164, ..., 0.0233, -0.0177, -0.0397],\n", + " ...,\n", + " [-0.0006, -0.0275, 0.0127, ..., -0.0185, 0.0335, -0.0133],\n", + " [ 0.0064, -0.0200, 0.0296, ..., 0.0041, -0.0114, -0.0221],\n", + " [ 0.0317, 0.0392, 0.0553, ..., 0.0191, 0.0188, -0.0176]])),\n", + " ('model.layers.3.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.3.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.4.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.4.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.4.mixer.in_proj.weight',\n", + " tensor([[-0.0266, 0.0092, -0.0260, ..., -0.0121, -0.0286, 0.0267],\n", + " [ 0.0144, -0.0053, -0.0060, ..., -0.0065, 0.0201, -0.0025],\n", + " [-0.0092, -0.0465, -0.0032, ..., 0.0192, -0.0026, 0.0104],\n", + " ...,\n", + " [-0.0210, -0.0286, -0.0148, ..., 0.0593, 0.0130, 0.0118],\n", + " [ 0.0361, -0.0070, 0.0054, ..., -0.0073, 0.0004, 0.0287],\n", + " [ 0.0450, -0.0286, 0.0191, ..., -0.0180, 0.0039, -0.0033]])),\n", + " ('model.layers.4.mixer.conv1d.weight',\n", + " tensor([[[ 0.1450, 0.2065, -0.1750, -0.4560]],\n", + " \n", + " [[-0.2889, -0.4707, -0.0741, 0.1254]],\n", + " \n", + " [[-0.4665, 0.1876, -0.4049, 0.1143]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0709, 0.2021, -0.0053, -0.1558]],\n", + " \n", + " [[-0.0195, -0.4046, -0.2437, -0.4405]],\n", + " \n", + " [[-0.3615, -0.4314, 0.1667, 0.3139]]])),\n", + " ('model.layers.4.mixer.conv1d.bias',\n", + " tensor([-0.3220, -0.4181, -0.0623, ..., 0.2788, 0.0518, 0.4607])),\n", + " ('model.layers.4.mixer.out_proj.weight',\n", + " tensor([[-0.0011, -0.0279, -0.0160, ..., -0.0222, 0.0262, 0.0234],\n", + " [ 0.0024, 0.0178, -0.0142, ..., 0.0048, -0.0145, 0.0332],\n", + " [-0.0084, -0.0037, 0.0054, ..., -0.0201, -0.0341, -0.0053],\n", + " ...,\n", + " [-0.0120, -0.0440, 0.0097, ..., -0.0070, -0.0129, 0.0170],\n", + " [ 0.0096, -0.0034, -0.0025, ..., 0.0242, 0.0047, 0.0093],\n", + " [ 0.0254, 0.0207, 0.0135, ..., 0.0204, -0.0185, -0.0026]])),\n", + " ('model.layers.4.mlp.gate_proj.weight',\n", + " tensor([[ 0.0049, 0.0087, 0.0081, ..., 0.0145, 0.0188, 0.0441],\n", + " [-0.0103, 0.0147, 0.0180, ..., -0.0190, 0.0182, 0.0160],\n", + " [-0.0041, 0.0289, 0.0106, ..., 0.0144, -0.0070, 0.0104],\n", + " ...,\n", + " [ 0.0086, 0.0079, 0.0155, ..., 0.0037, -0.0242, 0.0091],\n", + " [-0.0320, 0.0084, -0.0508, ..., 0.0003, -0.0120, 0.0129],\n", + " [ 0.0079, 0.0185, 0.0285, ..., -0.0324, 0.0444, -0.0147]])),\n", + " ('model.layers.4.mlp.up_proj.weight',\n", + " tensor([[ 3.4382e-03, 1.9171e-02, 4.1226e-03, ..., 1.3158e-02,\n", + " 3.6365e-02, -8.1017e-03],\n", + " [ 1.8713e-02, -2.7732e-03, 3.1982e-02, ..., -8.5724e-03,\n", + " -3.1505e-02, 2.1047e-03],\n", + " [ 1.2329e-02, 1.8352e-03, 9.2540e-03, ..., 2.9880e-02,\n", + " -2.7856e-04, -8.7440e-04],\n", + " ...,\n", + " [-2.2330e-02, -2.0716e-02, 9.0004e-05, ..., -1.6298e-02,\n", + " -1.9620e-02, 2.5112e-02],\n", + " [ 7.1659e-03, 1.2942e-02, 1.0291e-03, ..., -1.0113e-02,\n", + " -1.6838e-03, 2.0189e-02],\n", + " [ 7.2108e-03, 3.1229e-02, 2.2533e-03, ..., -2.0148e-02,\n", + " -1.3502e-02, -1.8923e-02]])),\n", + " ('model.layers.4.mlp.down_proj.weight',\n", + " tensor([[ 0.0140, -0.0129, 0.0005, ..., -0.0068, -0.0335, 0.0172],\n", + " [-0.0175, -0.0011, 0.0114, ..., -0.0087, -0.0048, -0.0231],\n", + " [-0.0053, -0.0079, -0.0172, ..., -0.0125, -0.0200, 0.0127],\n", + " ...,\n", + " [ 0.0321, -0.0039, 0.0142, ..., 0.0384, 0.0054, 0.0321],\n", + " [ 0.0041, -0.0150, 0.0141, ..., 0.0049, -0.0348, -0.0028],\n", + " [ 0.0176, 0.0132, 0.0090, ..., -0.0117, 0.0241, 0.0417]])),\n", + " ('model.layers.4.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.4.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.5.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.5.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.5.mixer.in_proj.weight',\n", + " tensor([[ 0.0270, 0.0124, 0.0098, ..., 0.0170, -0.0225, 0.0032],\n", + " [ 0.0245, -0.0008, 0.0226, ..., 0.0219, -0.0219, 0.0087],\n", + " [-0.0175, 0.0181, 0.0124, ..., 0.0038, -0.0094, 0.0079],\n", + " ...,\n", + " [-0.0080, -0.0011, 0.0316, ..., -0.0012, 0.0254, 0.0251],\n", + " [-0.0141, -0.0159, -0.0069, ..., 0.0147, -0.0161, -0.0093],\n", + " [ 0.0252, 0.0125, 0.0174, ..., -0.0065, 0.0110, 0.0272]])),\n", + " ('model.layers.5.mixer.conv1d.weight',\n", + " tensor([[[ 0.0684, -0.4353, 0.3899, 0.3199]],\n", + " \n", + " [[ 0.4136, 0.4306, -0.4871, 0.4781]],\n", + " \n", + " [[-0.2516, 0.2109, 0.3891, 0.1501]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0781, -0.0675, -0.2995, -0.1805]],\n", + " \n", + " [[-0.3360, -0.4148, 0.1846, -0.1013]],\n", + " \n", + " [[ 0.1725, 0.1929, -0.0337, 0.1375]]])),\n", + " ('model.layers.5.mixer.conv1d.bias',\n", + " tensor([-0.4975, -0.0629, -0.2420, ..., -0.2253, 0.2512, 0.2788])),\n", + " ('model.layers.5.mixer.out_proj.weight',\n", + " tensor([[ 1.4306e-02, 1.3230e-02, -2.4141e-02, ..., 1.1763e-02,\n", + " 7.0706e-03, -4.7970e-03],\n", + " [ 2.7478e-02, 1.5179e-03, 1.9229e-02, ..., 1.0928e-02,\n", + " 2.2802e-02, -2.9729e-03],\n", + " [ 1.0169e-02, -1.0741e-02, 2.0628e-02, ..., -1.8109e-02,\n", + " -4.2582e-03, 2.4007e-02],\n", + " ...,\n", + " [-3.2843e-03, 3.7835e-03, -6.7958e-03, ..., -2.6205e-02,\n", + " -2.0391e-02, 5.3912e-03],\n", + " [ 1.2515e-02, -6.4975e-03, 9.9616e-05, ..., 1.0444e-02,\n", + " -2.0596e-02, -8.2915e-03],\n", + " [ 1.7899e-02, 2.0418e-02, -1.9891e-02, ..., -6.6709e-03,\n", + " -3.8566e-02, 2.7005e-02]])),\n", + " ('model.layers.5.mlp.gate_proj.weight',\n", + " tensor([[-2.3807e-03, 2.2714e-03, 2.2736e-05, ..., -2.3039e-03,\n", + " 3.6159e-02, -1.7253e-02],\n", + " [ 3.6929e-02, -6.2031e-03, 1.3606e-02, ..., 2.3592e-02,\n", + " 4.4487e-03, -9.6723e-03],\n", + " [ 4.7507e-02, 2.6413e-02, 1.6759e-02, ..., 1.1910e-02,\n", + " 1.2872e-02, -1.0443e-02],\n", + " ...,\n", + " [-2.0354e-02, -3.9074e-03, 9.7952e-03, ..., 1.0730e-02,\n", + " 2.8752e-02, -8.0048e-03],\n", + " [ 2.5331e-02, -9.9732e-03, 1.0772e-02, ..., 2.0420e-02,\n", + " -3.2179e-02, -1.6437e-02],\n", + " [-3.4425e-02, -1.4578e-02, 2.9686e-03, ..., 4.5907e-02,\n", + " 7.7639e-03, -2.2494e-03]])),\n", + " ('model.layers.5.mlp.up_proj.weight',\n", + " tensor([[ 1.5868e-02, -1.9222e-02, -1.2880e-03, ..., 8.3353e-03,\n", + " -1.8538e-02, 6.7395e-03],\n", + " [-1.8051e-02, -5.0142e-02, -2.2177e-03, ..., -9.3852e-03,\n", + " -3.0374e-02, 2.5795e-02],\n", + " [-1.1737e-02, 2.6278e-02, -2.3205e-02, ..., -1.8399e-03,\n", + " 1.4115e-02, -2.6438e-02],\n", + " ...,\n", + " [ 2.7706e-02, -2.5067e-03, -8.7058e-03, ..., 2.1662e-03,\n", + " -4.9858e-02, -1.1575e-02],\n", + " [-9.5670e-04, 2.1698e-02, -5.4794e-03, ..., -1.0661e-02,\n", + " 1.8568e-02, 5.2615e-03],\n", + " [ 1.0739e-03, 2.2945e-02, 3.0835e-02, ..., 4.1212e-03,\n", + " 1.2643e-02, -1.1568e-05]])),\n", + " ('model.layers.5.mlp.down_proj.weight',\n", + " tensor([[ 0.0052, -0.0343, 0.0072, ..., 0.0004, 0.0320, 0.0362],\n", + " [ 0.0171, -0.0238, -0.0316, ..., 0.0231, 0.0377, 0.0141],\n", + " [-0.0205, 0.0152, 0.0002, ..., -0.0061, -0.0353, -0.0138],\n", + " ...,\n", + " [-0.0039, -0.0039, 0.0326, ..., -0.0208, 0.0160, 0.0185],\n", + " [ 0.0176, -0.0300, -0.0024, ..., -0.0292, -0.0254, -0.0366],\n", + " [ 0.0361, 0.0243, -0.0253, ..., -0.0036, -0.0099, -0.0133]])),\n", + " ('model.layers.5.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.5.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.6.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.6.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.6.mixer.in_proj.weight',\n", + " tensor([[-0.0505, -0.0650, 0.0059, ..., 0.0060, 0.0347, 0.0149],\n", + " [-0.0216, 0.0057, -0.0281, ..., -0.0162, 0.0081, 0.0016],\n", + " [-0.0339, -0.0314, 0.0253, ..., 0.0030, 0.0139, -0.0039],\n", + " ...,\n", + " [ 0.0355, -0.0238, -0.0015, ..., 0.0063, 0.0284, -0.0089],\n", + " [ 0.0093, -0.0381, -0.0261, ..., -0.0170, -0.0170, -0.0288],\n", + " [-0.0228, -0.0110, 0.0107, ..., 0.0300, 0.0010, 0.0141]])),\n", + " ('model.layers.6.mixer.conv1d.weight',\n", + " tensor([[[ 0.4364, 0.2888, 0.2343, 0.3226]],\n", + " \n", + " [[ 0.2804, 0.3558, 0.4061, -0.0480]],\n", + " \n", + " [[ 0.4964, 0.0709, 0.0748, 0.0971]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4291, 0.2445, -0.3121, 0.4013]],\n", + " \n", + " [[-0.1590, -0.1516, 0.0804, 0.2009]],\n", + " \n", + " [[ 0.1686, 0.0492, -0.2932, 0.1381]]])),\n", + " ('model.layers.6.mixer.conv1d.bias',\n", + " tensor([ 0.4241, -0.0500, 0.3393, ..., 0.1598, -0.4924, -0.3241])),\n", + " ('model.layers.6.mixer.out_proj.weight',\n", + " tensor([[ 0.0026, 0.0272, 0.0005, ..., 0.0434, -0.0293, -0.0105],\n", + " [ 0.0323, -0.0515, 0.0107, ..., -0.0406, 0.0252, -0.0038],\n", + " [-0.0156, -0.0078, 0.0173, ..., 0.0312, -0.0014, -0.0014],\n", + " ...,\n", + " [ 0.0014, -0.0522, -0.0154, ..., 0.0090, -0.0050, -0.0049],\n", + " [ 0.0350, 0.0099, -0.0014, ..., -0.0008, -0.0185, -0.0033],\n", + " [ 0.0134, 0.0002, 0.0325, ..., -0.0129, 0.0165, -0.0265]])),\n", + " ('model.layers.6.mlp.gate_proj.weight',\n", + " tensor([[-0.0011, 0.0202, 0.0236, ..., -0.0137, -0.0063, 0.0085],\n", + " [ 0.0163, 0.0261, 0.0120, ..., -0.0003, -0.0254, 0.0001],\n", + " [ 0.0318, -0.0121, 0.0103, ..., -0.0053, 0.0194, 0.0530],\n", + " ...,\n", + " [ 0.0039, 0.0228, -0.0147, ..., 0.0027, 0.0092, -0.0033],\n", + " [-0.0040, 0.0144, 0.0038, ..., -0.0106, -0.0022, 0.0094],\n", + " [ 0.0220, 0.0296, 0.0550, ..., 0.0079, -0.0135, -0.0092]])),\n", + " ('model.layers.6.mlp.up_proj.weight',\n", + " tensor([[ 0.0061, -0.0291, -0.0133, ..., 0.0054, -0.0049, -0.0028],\n", + " [-0.0032, -0.0201, 0.0218, ..., -0.0155, -0.0264, 0.0496],\n", + " [-0.0046, 0.0384, -0.0093, ..., 0.0356, -0.0245, 0.0175],\n", + " ...,\n", + " [-0.0111, -0.0092, -0.0143, ..., 0.0010, -0.0453, 0.0024],\n", + " [ 0.0078, -0.0025, 0.0227, ..., -0.0130, 0.0118, 0.0095],\n", + " [ 0.0234, -0.0114, -0.0102, ..., -0.0179, -0.0066, -0.0115]])),\n", + " ('model.layers.6.mlp.down_proj.weight',\n", + " tensor([[ 3.6976e-02, 1.7124e-02, -2.1290e-02, ..., -2.5206e-02,\n", + " 4.8023e-03, 9.8474e-03],\n", + " [-7.2866e-03, -5.4149e-03, -2.2242e-03, ..., -8.1606e-03,\n", + " -9.5275e-04, -1.8121e-02],\n", + " [-8.3493e-03, 1.2509e-02, 1.0773e-02, ..., 2.7061e-02,\n", + " 2.8131e-03, 5.8219e-03],\n", + " ...,\n", + " [ 8.7099e-03, 3.9196e-02, -3.5129e-03, ..., -2.3595e-02,\n", + " -8.3965e-03, 2.0074e-02],\n", + " [-2.7467e-02, -2.8721e-03, -2.2291e-02, ..., 9.7135e-03,\n", + " 3.4947e-02, -2.2158e-02],\n", + " [ 6.1744e-03, -4.7684e-03, 4.6690e-04, ..., -3.2948e-03,\n", + " 4.0735e-05, 3.3651e-02]])),\n", + " ('model.layers.6.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.6.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.7.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.7.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.7.mixer.in_proj.weight',\n", + " tensor([[-0.0045, -0.0288, 0.0362, ..., -0.0092, -0.0026, 0.0051],\n", + " [ 0.0160, 0.0139, 0.0057, ..., 0.0121, 0.0071, 0.0134],\n", + " [ 0.0062, 0.0181, 0.0161, ..., -0.0284, -0.0014, -0.0171],\n", + " ...,\n", + " [-0.0053, 0.0067, 0.0095, ..., -0.0175, 0.0235, 0.0125],\n", + " [-0.0048, 0.0041, 0.0038, ..., 0.0099, 0.0194, 0.0124],\n", + " [ 0.0131, 0.0073, -0.0284, ..., 0.0138, -0.0218, 0.0019]])),\n", + " ('model.layers.7.mixer.conv1d.weight',\n", + " tensor([[[ 0.2528, -0.0556, -0.3225, 0.1327]],\n", + " \n", + " [[-0.0437, 0.4941, -0.4075, 0.1062]],\n", + " \n", + " [[-0.3428, 0.2675, 0.1871, 0.0260]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0409, -0.4458, 0.4488, 0.2841]],\n", + " \n", + " [[-0.2370, -0.3965, 0.0656, -0.1339]],\n", + " \n", + " [[ 0.4677, 0.0073, 0.3741, 0.1525]]])),\n", + " ('model.layers.7.mixer.conv1d.bias',\n", + " tensor([-0.1844, -0.1347, 0.0043, ..., -0.3839, -0.2167, -0.4637])),\n", + " ('model.layers.7.mixer.out_proj.weight',\n", + " tensor([[-2.8471e-02, 3.9783e-03, 6.0125e-03, ..., -1.6079e-02,\n", + " 1.4225e-02, 2.8166e-02],\n", + " [ 5.4680e-03, -5.1414e-03, 5.3077e-05, ..., 1.8734e-02,\n", + " 3.7454e-03, 1.7579e-02],\n", + " [-1.2955e-02, 1.4954e-02, 6.4922e-03, ..., -2.6830e-02,\n", + " 1.4766e-02, -1.8002e-02],\n", + " ...,\n", + " [ 1.7150e-02, 4.6781e-02, -1.1136e-02, ..., 4.7242e-03,\n", + " -1.3072e-02, -1.0412e-02],\n", + " [ 5.5498e-03, -3.0803e-02, -2.4880e-02, ..., -4.2644e-03,\n", + " -1.1047e-02, 1.5815e-02],\n", + " [ 1.7242e-02, 2.7994e-02, -4.8186e-04, ..., -2.2003e-02,\n", + " -2.1834e-02, -2.1826e-02]])),\n", + " ('model.layers.7.mlp.gate_proj.weight',\n", + " tensor([[-0.0302, -0.0160, -0.0341, ..., -0.0121, 0.0007, -0.0338],\n", + " [-0.0186, 0.0257, -0.0154, ..., 0.0153, -0.0029, 0.0163],\n", + " [ 0.0170, 0.0223, -0.0185, ..., -0.0020, 0.0061, 0.0174],\n", + " ...,\n", + " [-0.0044, 0.0044, 0.0077, ..., -0.0183, 0.0041, -0.0003],\n", + " [ 0.0168, 0.0149, -0.0221, ..., 0.0112, 0.0357, 0.0042],\n", + " [ 0.0310, -0.0217, 0.0070, ..., -0.0394, -0.0065, 0.0204]])),\n", + " ('model.layers.7.mlp.up_proj.weight',\n", + " tensor([[-0.0031, -0.0110, 0.0091, ..., 0.0152, -0.0013, 0.0096],\n", + " [ 0.0013, 0.0354, -0.0037, ..., 0.0130, 0.0204, 0.0262],\n", + " [-0.0075, -0.0044, 0.0207, ..., 0.0057, 0.0115, 0.0151],\n", + " ...,\n", + " [-0.0015, 0.0095, -0.0100, ..., -0.0150, 0.0105, -0.0350],\n", + " [-0.0300, -0.0092, -0.0176, ..., -0.0113, 0.0164, -0.0117],\n", + " [-0.0291, -0.0085, 0.0058, ..., 0.0386, -0.0174, -0.0092]])),\n", + " ('model.layers.7.mlp.down_proj.weight',\n", + " tensor([[-0.0276, 0.0017, -0.0217, ..., 0.0302, -0.0079, -0.0003],\n", + " [ 0.0379, 0.0052, 0.0052, ..., 0.0145, 0.0139, -0.0143],\n", + " [ 0.0176, -0.0028, 0.0172, ..., -0.0205, -0.0165, -0.0040],\n", + " ...,\n", + " [ 0.0095, -0.0139, 0.0077, ..., -0.0080, 0.0339, 0.0172],\n", + " [-0.0177, 0.0009, -0.0245, ..., 0.0040, 0.0258, 0.0202],\n", + " [-0.0064, -0.0270, 0.0041, ..., -0.0133, -0.0040, 0.0038]])),\n", + " ('model.layers.7.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.7.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.8.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.8.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.8.mixer.in_proj.weight',\n", + " tensor([[ 0.0050, 0.0270, -0.0196, ..., -0.0121, -0.0090, 0.0083],\n", + " [-0.0083, -0.0177, 0.0159, ..., 0.0298, -0.0202, -0.0265],\n", + " [ 0.0058, 0.0186, 0.0125, ..., -0.0067, -0.0255, 0.0298],\n", + " ...,\n", + " [-0.0164, 0.0012, 0.0023, ..., -0.0355, 0.0347, -0.0011],\n", + " [-0.0371, 0.0033, 0.0345, ..., -0.0097, 0.0019, 0.0185],\n", + " [-0.0322, -0.0160, 0.0072, ..., -0.0195, -0.0229, 0.0118]])),\n", + " ('model.layers.8.mixer.conv1d.weight',\n", + " tensor([[[-0.0520, 0.3004, -0.1990, 0.2512]],\n", + " \n", + " [[-0.4120, -0.0055, 0.1484, -0.3316]],\n", + " \n", + " [[ 0.3939, -0.0567, 0.1432, 0.1880]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.2849, 0.2494, -0.2141, -0.3375]],\n", + " \n", + " [[-0.2823, -0.2402, 0.2228, 0.2331]],\n", + " \n", + " [[ 0.1914, 0.4269, 0.1228, -0.3408]]])),\n", + " ('model.layers.8.mixer.conv1d.bias',\n", + " tensor([0.1304, 0.2065, 0.3084, ..., 0.3863, 0.4883, 0.4724])),\n", + " ('model.layers.8.mixer.out_proj.weight',\n", + " tensor([[ 0.0008, -0.0019, 0.0084, ..., -0.0003, 0.0045, 0.0024],\n", + " [ 0.0137, -0.0003, -0.0031, ..., 0.0013, 0.0131, 0.0090],\n", + " [ 0.0095, 0.0488, -0.0355, ..., 0.0344, -0.0229, -0.0150],\n", + " ...,\n", + " [ 0.0029, 0.0164, -0.0380, ..., -0.0005, -0.0031, 0.0127],\n", + " [-0.0039, 0.0283, 0.0295, ..., 0.0271, -0.0105, -0.0158],\n", + " [-0.0057, -0.0178, 0.0129, ..., 0.0323, -0.0091, 0.0178]])),\n", + " ('model.layers.8.mlp.gate_proj.weight',\n", + " tensor([[-0.0047, 0.0037, -0.0129, ..., 0.0255, -0.0118, 0.0084],\n", + " [ 0.0418, -0.0020, 0.0205, ..., 0.0161, 0.0306, 0.0250],\n", + " [ 0.0011, 0.0144, 0.0204, ..., -0.0007, 0.0298, -0.0067],\n", + " ...,\n", + " [-0.0536, -0.0083, -0.0049, ..., -0.0028, 0.0301, -0.0205],\n", + " [ 0.0031, 0.0139, 0.0070, ..., 0.0120, 0.0004, -0.0226],\n", + " [ 0.0114, -0.0173, 0.0212, ..., -0.0413, -0.0069, 0.0007]])),\n", + " ('model.layers.8.mlp.up_proj.weight',\n", + " tensor([[-0.0005, 0.0028, -0.0137, ..., 0.0078, 0.0348, 0.0006],\n", + " [-0.0020, 0.0300, -0.0056, ..., -0.0258, -0.0130, -0.0212],\n", + " [-0.0135, -0.0111, 0.0151, ..., 0.0043, -0.0426, -0.0109],\n", + " ...,\n", + " [ 0.0273, 0.0057, -0.0108, ..., -0.0205, 0.0005, -0.0239],\n", + " [ 0.0226, 0.0325, -0.0187, ..., 0.0069, -0.0132, -0.0002],\n", + " [ 0.0280, -0.0007, -0.0047, ..., 0.0159, -0.0054, -0.0172]])),\n", + " ('model.layers.8.mlp.down_proj.weight',\n", + " tensor([[-0.0091, 0.0072, 0.0030, ..., 0.0025, -0.0159, -0.0277],\n", + " [ 0.0159, -0.0260, -0.0076, ..., -0.0059, -0.0129, 0.0358],\n", + " [ 0.0026, -0.0357, -0.0138, ..., -0.0326, -0.0291, 0.0010],\n", + " ...,\n", + " [-0.0237, 0.0272, -0.0130, ..., -0.0280, 0.0097, -0.0563],\n", + " [ 0.0092, 0.0056, 0.0079, ..., -0.0224, 0.0039, -0.0054],\n", + " [-0.0109, -0.0241, -0.0223, ..., -0.0187, 0.0190, 0.0082]])),\n", + " ('model.layers.8.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.8.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.9.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.9.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.9.mixer.in_proj.weight',\n", + " tensor([[ 4.9824e-02, 5.7576e-03, -5.1022e-03, ..., -2.5615e-02,\n", + " 7.1750e-04, 1.5247e-02],\n", + " [-2.8065e-02, -1.2649e-02, -2.3566e-02, ..., 1.7742e-02,\n", + " -1.1202e-02, -2.1476e-02],\n", + " [ 2.0911e-02, 1.6496e-02, -1.9818e-02, ..., 4.0223e-02,\n", + " 1.8544e-02, -2.3633e-02],\n", + " ...,\n", + " [-4.3387e-02, -1.6504e-02, 2.2008e-02, ..., -2.5138e-03,\n", + " -5.6073e-03, -4.8212e-03],\n", + " [-1.9964e-05, -1.5835e-02, 1.2977e-02, ..., 4.1913e-03,\n", + " 4.5898e-02, -3.5822e-02],\n", + " [ 3.1376e-02, -5.4614e-03, -2.5093e-02, ..., -3.7903e-03,\n", + " 1.3560e-02, 3.3366e-02]])),\n", + " ('model.layers.9.mixer.conv1d.weight',\n", + " tensor([[[ 0.1986, -0.1666, -0.4140, -0.4607]],\n", + " \n", + " [[-0.3454, -0.3973, 0.2169, -0.2138]],\n", + " \n", + " [[ 0.2006, -0.3736, 0.3944, -0.0589]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4604, 0.1224, -0.2571, -0.0286]],\n", + " \n", + " [[-0.2723, -0.1617, 0.3483, 0.2299]],\n", + " \n", + " [[ 0.4866, 0.2559, 0.3969, 0.0554]]])),\n", + " ('model.layers.9.mixer.conv1d.bias',\n", + " tensor([ 0.3388, 0.4633, -0.3762, ..., -0.3491, -0.2971, 0.0494])),\n", + " ('model.layers.9.mixer.out_proj.weight',\n", + " tensor([[ 0.0023, -0.0181, 0.0358, ..., 0.0243, 0.0070, -0.0183],\n", + " [ 0.0006, 0.0065, 0.0057, ..., -0.0351, -0.0107, 0.0132],\n", + " [ 0.0153, -0.0038, 0.0059, ..., -0.0285, -0.0247, -0.0104],\n", + " ...,\n", + " [ 0.0244, -0.0120, 0.0064, ..., -0.0133, 0.0263, 0.0016],\n", + " [ 0.0056, -0.0111, 0.0029, ..., -0.0017, -0.0172, -0.0071],\n", + " [-0.0056, -0.0192, -0.0238, ..., 0.0245, -0.0102, -0.0331]])),\n", + " ('model.layers.9.mlp.gate_proj.weight',\n", + " tensor([[-0.0132, 0.0014, -0.0413, ..., -0.0254, -0.0245, 0.0031],\n", + " [-0.0195, -0.0107, -0.0192, ..., 0.0012, -0.0026, 0.0148],\n", + " [-0.0074, -0.0070, -0.0078, ..., 0.0013, -0.0011, -0.0111],\n", + " ...,\n", + " [-0.0137, 0.0302, 0.0084, ..., -0.0063, -0.0065, 0.0240],\n", + " [ 0.0072, 0.0134, 0.0161, ..., 0.0122, 0.0182, 0.0137],\n", + " [ 0.0079, 0.0008, 0.0160, ..., 0.0281, 0.0226, 0.0058]])),\n", + " ('model.layers.9.mlp.up_proj.weight',\n", + " tensor([[ 0.0078, 0.0153, -0.0155, ..., 0.0153, -0.0164, -0.0140],\n", + " [-0.0072, -0.0050, 0.0030, ..., 0.0146, -0.0148, -0.0080],\n", + " [ 0.0165, -0.0078, 0.0005, ..., -0.0545, -0.0096, 0.0296],\n", + " ...,\n", + " [-0.0253, 0.0183, -0.0081, ..., -0.0061, 0.0270, -0.0003],\n", + " [-0.0015, -0.0320, 0.0361, ..., -0.0087, 0.0341, -0.0157],\n", + " [ 0.0041, 0.0102, -0.0195, ..., -0.0441, -0.0106, 0.0275]])),\n", + " ('model.layers.9.mlp.down_proj.weight',\n", + " tensor([[-6.3367e-02, -1.8214e-02, 5.7221e-03, ..., 2.1307e-02,\n", + " -3.0707e-02, -1.3281e-02],\n", + " [-7.7457e-05, -9.1894e-05, 6.8686e-03, ..., -4.7175e-03,\n", + " -1.1585e-03, -2.7604e-02],\n", + " [ 2.9301e-02, -5.9431e-03, -2.5356e-03, ..., -2.7858e-02,\n", + " 1.1647e-02, 1.1245e-02],\n", + " ...,\n", + " [-1.0442e-02, -9.6151e-03, -3.6635e-02, ..., -1.1052e-02,\n", + " -4.5122e-03, 4.0012e-03],\n", + " [ 3.2950e-02, -1.3836e-03, -7.8318e-03, ..., -1.2788e-03,\n", + " 2.3422e-02, -3.2098e-02],\n", + " [-9.2294e-03, 1.3838e-02, -2.0327e-02, ..., -3.8760e-02,\n", + " 2.2118e-02, 1.0696e-02]])),\n", + " ('model.layers.9.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.9.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.10.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.10.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.10.mixer.in_proj.weight',\n", + " tensor([[ 0.0096, -0.0159, 0.0141, ..., 0.0111, 0.0218, 0.0220],\n", + " [-0.0381, -0.0015, 0.0126, ..., -0.0066, -0.0034, -0.0119],\n", + " [ 0.0223, 0.0032, -0.0195, ..., -0.0107, -0.0018, 0.0059],\n", + " ...,\n", + " [-0.0256, -0.0170, -0.0362, ..., -0.0007, -0.0039, 0.0075],\n", + " [ 0.0136, -0.0045, 0.0128, ..., -0.0017, 0.0083, -0.0004],\n", + " [-0.0246, -0.0021, 0.0073, ..., 0.0020, 0.0071, 0.0090]])),\n", + " ('model.layers.10.mixer.conv1d.weight',\n", + " tensor([[[ 0.0463, -0.4497, -0.0679, -0.2209]],\n", + " \n", + " [[-0.3805, 0.4459, 0.1999, -0.4996]],\n", + " \n", + " [[ 0.1529, 0.1789, -0.1535, 0.1824]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.1087, -0.4478, -0.0420, 0.3437]],\n", + " \n", + " [[-0.2809, -0.4617, 0.3209, 0.4873]],\n", + " \n", + " [[ 0.1139, -0.0060, -0.0219, 0.0853]]])),\n", + " ('model.layers.10.mixer.conv1d.bias',\n", + " tensor([ 0.1364, -0.0475, 0.0849, ..., 0.1928, 0.2075, 0.1058])),\n", + " ('model.layers.10.mixer.out_proj.weight',\n", + " tensor([[-0.0164, -0.0188, 0.0174, ..., -0.0106, -0.0107, -0.0036],\n", + " [ 0.0048, -0.0016, -0.0444, ..., -0.0182, -0.0264, -0.0038],\n", + " [ 0.0089, -0.0225, -0.0002, ..., -0.0141, -0.0008, -0.0037],\n", + " ...,\n", + " [-0.0005, 0.0159, 0.0033, ..., 0.0187, -0.0064, 0.0233],\n", + " [-0.0050, 0.0296, 0.0147, ..., -0.0018, 0.0137, -0.0346],\n", + " [-0.0064, -0.0132, -0.0434, ..., -0.0173, -0.0113, -0.0175]])),\n", + " ('model.layers.10.mlp.gate_proj.weight',\n", + " tensor([[-0.0174, -0.0053, -0.0325, ..., -0.0072, -0.0280, 0.0033],\n", + " [ 0.0006, -0.0160, 0.0346, ..., 0.0019, 0.0059, 0.0198],\n", + " [ 0.0231, -0.0187, 0.0115, ..., 0.0085, 0.0080, 0.0061],\n", + " ...,\n", + " [ 0.0153, 0.0241, -0.0184, ..., 0.0089, -0.0242, 0.0010],\n", + " [-0.0019, -0.0322, 0.0011, ..., -0.0097, -0.0305, 0.0065],\n", + " [-0.0107, 0.0240, 0.0168, ..., 0.0226, -0.0238, 0.0117]])),\n", + " ('model.layers.10.mlp.up_proj.weight',\n", + " tensor([[-0.0072, 0.0352, 0.0282, ..., -0.0025, -0.0114, 0.0129],\n", + " [-0.0102, 0.0196, 0.0760, ..., 0.0461, -0.0058, -0.0112],\n", + " [-0.0271, 0.0323, -0.0069, ..., 0.0133, -0.0371, -0.0619],\n", + " ...,\n", + " [ 0.0100, 0.0011, 0.0262, ..., -0.0232, 0.0217, 0.0002],\n", + " [ 0.0151, -0.0266, -0.0074, ..., 0.0096, 0.0036, 0.0033],\n", + " [ 0.0004, 0.0103, 0.0363, ..., -0.0095, -0.0309, -0.0059]])),\n", + " ('model.layers.10.mlp.down_proj.weight',\n", + " tensor([[ 0.0124, -0.0225, -0.0294, ..., 0.0280, 0.0056, 0.0231],\n", + " [ 0.0124, -0.0030, 0.0014, ..., 0.0323, 0.0094, -0.0034],\n", + " [-0.0078, 0.0041, -0.0056, ..., 0.0241, -0.0278, -0.0152],\n", + " ...,\n", + " [-0.0044, 0.0025, -0.0161, ..., -0.0075, -0.0126, 0.0014],\n", + " [-0.0109, -0.0050, 0.0327, ..., -0.0300, -0.0048, 0.0284],\n", + " [ 0.0050, -0.0183, 0.0086, ..., -0.0072, 0.0139, -0.0010]])),\n", + " ('model.layers.10.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.10.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.11.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.11.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.11.mixer.in_proj.weight',\n", + " tensor([[-0.0133, 0.0225, 0.0486, ..., -0.0214, -0.0120, -0.0150],\n", + " [ 0.0183, 0.0020, 0.0079, ..., -0.0163, 0.0016, -0.0214],\n", + " [-0.0276, -0.0112, 0.0121, ..., -0.0057, -0.0143, -0.0462],\n", + " ...,\n", + " [-0.0142, -0.0080, -0.0194, ..., 0.0087, -0.0212, -0.0140],\n", + " [ 0.0060, -0.0005, -0.0171, ..., -0.0017, 0.0223, 0.0169],\n", + " [-0.0290, -0.0016, 0.0117, ..., 0.0037, 0.0047, 0.0152]])),\n", + " ('model.layers.11.mixer.conv1d.weight',\n", + " tensor([[[-0.2822, -0.4216, 0.4786, 0.0802]],\n", + " \n", + " [[-0.3671, 0.1761, -0.2686, 0.1631]],\n", + " \n", + " [[-0.3902, -0.2811, -0.0748, 0.4662]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.1623, 0.2871, -0.4585, 0.4755]],\n", + " \n", + " [[-0.0260, 0.4541, -0.2983, 0.2297]],\n", + " \n", + " [[-0.2991, -0.3590, -0.3256, -0.1434]]])),\n", + " ('model.layers.11.mixer.conv1d.bias',\n", + " tensor([ 0.1218, -0.0542, 0.3485, ..., 0.0528, 0.2711, -0.2811])),\n", + " ('model.layers.11.mixer.out_proj.weight',\n", + " tensor([[ 0.0032, 0.0028, -0.0122, ..., -0.0299, -0.0105, 0.0021],\n", + " [-0.0466, -0.0170, -0.0017, ..., 0.0156, -0.0287, 0.0066],\n", + " [ 0.0016, 0.0054, -0.0071, ..., -0.0240, 0.0215, -0.0046],\n", + " ...,\n", + " [-0.0210, 0.0034, -0.0267, ..., 0.0461, -0.0076, -0.0016],\n", + " [-0.0012, -0.0101, 0.0196, ..., 0.0121, -0.0043, -0.0143],\n", + " [-0.0067, 0.0086, 0.0134, ..., 0.0080, 0.0255, 0.0225]])),\n", + " ('model.layers.11.mlp.gate_proj.weight',\n", + " tensor([[ 0.0179, -0.0429, -0.0134, ..., 0.0110, 0.0368, -0.0259],\n", + " [ 0.0013, -0.0231, 0.0072, ..., -0.0056, -0.0012, -0.0037],\n", + " [-0.0172, -0.0162, 0.0088, ..., -0.0175, 0.0079, -0.0065],\n", + " ...,\n", + " [ 0.0287, -0.0289, 0.0045, ..., 0.0039, 0.0269, 0.0199],\n", + " [ 0.0043, -0.0202, -0.0261, ..., 0.0104, -0.0161, -0.0057],\n", + " [-0.0154, 0.0085, 0.0061, ..., 0.0208, 0.0001, 0.0166]])),\n", + " ('model.layers.11.mlp.up_proj.weight',\n", + " tensor([[-0.0107, 0.0328, 0.0065, ..., -0.0190, -0.0082, -0.0047],\n", + " [-0.0001, 0.0102, 0.0310, ..., -0.0396, -0.0278, -0.0095],\n", + " [-0.0288, 0.0052, 0.0137, ..., -0.0220, 0.0007, -0.0170],\n", + " ...,\n", + " [ 0.0213, -0.0074, -0.0033, ..., 0.0183, 0.0336, -0.0180],\n", + " [-0.0098, -0.0162, 0.0486, ..., 0.0191, 0.0064, 0.0269],\n", + " [-0.0251, 0.0081, 0.0053, ..., 0.0110, 0.0023, 0.0041]])),\n", + " ('model.layers.11.mlp.down_proj.weight',\n", + " tensor([[ 0.0166, -0.0410, 0.0066, ..., -0.0273, 0.0220, 0.0184],\n", + " [ 0.0092, 0.0087, -0.0136, ..., 0.0013, -0.0205, 0.0247],\n", + " [-0.0252, -0.0040, -0.0112, ..., -0.0331, 0.0201, -0.0038],\n", + " ...,\n", + " [ 0.0072, 0.0190, 0.0089, ..., 0.0098, -0.0235, -0.0141],\n", + " [-0.0045, -0.0381, -0.0134, ..., 0.0171, -0.0077, -0.0180],\n", + " [ 0.0109, 0.0060, 0.0048, ..., -0.0108, -0.0122, 0.0110]])),\n", + " ('model.layers.11.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.11.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.12.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.12.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.12.mixer.in_proj.weight',\n", + " tensor([[ 0.0043, 0.0138, 0.0138, ..., -0.0042, 0.0121, -0.0190],\n", + " [ 0.0002, -0.0199, 0.0315, ..., 0.0170, 0.0051, -0.0062],\n", + " [-0.0053, 0.0043, 0.0283, ..., -0.0087, 0.0069, -0.0160],\n", + " ...,\n", + " [-0.0313, 0.0200, 0.0036, ..., 0.0147, 0.0153, 0.0098],\n", + " [-0.0157, 0.0120, -0.0112, ..., 0.0166, -0.0005, 0.0066],\n", + " [-0.0271, 0.0037, 0.0163, ..., 0.0304, 0.0023, 0.0083]])),\n", + " ('model.layers.12.mixer.conv1d.weight',\n", + " tensor([[[-0.4295, -0.2474, -0.2324, -0.2138]],\n", + " \n", + " [[ 0.3607, -0.4824, 0.1667, 0.1348]],\n", + " \n", + " [[ 0.3596, 0.1167, 0.1089, -0.4010]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3527, -0.3346, -0.3755, 0.1450]],\n", + " \n", + " [[-0.1921, -0.0632, -0.4885, -0.3986]],\n", + " \n", + " [[ 0.1950, 0.3037, -0.1630, 0.0353]]])),\n", + " ('model.layers.12.mixer.conv1d.bias',\n", + " tensor([0.3103, 0.0451, 0.4533, ..., 0.0235, 0.1819, 0.3933])),\n", + " ('model.layers.12.mixer.out_proj.weight',\n", + " tensor([[ 0.0167, -0.0197, -0.0054, ..., 0.0096, 0.0271, -0.0118],\n", + " [ 0.0167, -0.0455, 0.0001, ..., 0.0003, 0.0265, 0.0111],\n", + " [ 0.0231, -0.0113, 0.0195, ..., -0.0171, -0.0044, -0.0244],\n", + " ...,\n", + " [ 0.0042, 0.0048, 0.0357, ..., 0.0126, -0.0288, 0.0149],\n", + " [ 0.0192, 0.0078, 0.0126, ..., 0.0029, 0.0255, -0.0203],\n", + " [-0.0054, -0.0543, 0.0039, ..., -0.0240, 0.0282, 0.0082]])),\n", + " ('model.layers.12.mlp.gate_proj.weight',\n", + " tensor([[-0.0417, -0.0193, -0.0022, ..., 0.0031, 0.0337, 0.0175],\n", + " [ 0.0215, -0.0109, -0.0657, ..., -0.0145, -0.0475, -0.0091],\n", + " [-0.0225, -0.0012, -0.0020, ..., -0.0291, 0.0097, 0.0163],\n", + " ...,\n", + " [-0.0018, 0.0048, -0.0265, ..., -0.0056, 0.0446, 0.0045],\n", + " [ 0.0270, 0.0086, -0.0110, ..., -0.0038, 0.0176, 0.0138],\n", + " [-0.0134, 0.0046, -0.0186, ..., -0.0098, 0.0191, 0.0095]])),\n", + " ('model.layers.12.mlp.up_proj.weight',\n", + " tensor([[ 0.0180, 0.0075, 0.0147, ..., 0.0142, 0.0291, -0.0303],\n", + " [-0.0079, -0.0277, -0.0151, ..., -0.0069, -0.0045, -0.0223],\n", + " [ 0.0180, -0.0087, 0.0074, ..., 0.0215, 0.0274, -0.0199],\n", + " ...,\n", + " [-0.0215, -0.0115, 0.0140, ..., -0.0283, -0.0171, -0.0229],\n", + " [ 0.0231, -0.0179, -0.0386, ..., 0.0364, 0.0311, 0.0048],\n", + " [-0.0111, 0.0079, 0.0328, ..., 0.0285, 0.0423, 0.0039]])),\n", + " ('model.layers.12.mlp.down_proj.weight',\n", + " tensor([[-0.0361, 0.0192, -0.0005, ..., -0.0151, 0.0116, -0.0068],\n", + " [ 0.0203, -0.0064, 0.0061, ..., 0.0325, -0.0004, -0.0299],\n", + " [-0.0028, 0.0131, 0.0141, ..., -0.0108, -0.0070, -0.0090],\n", + " ...,\n", + " [ 0.0165, -0.0198, -0.0242, ..., 0.0162, 0.0099, 0.0025],\n", + " [ 0.0148, 0.0056, -0.0139, ..., 0.0108, -0.0477, 0.0225],\n", + " [ 0.0156, 0.0249, -0.0287, ..., -0.0200, -0.0496, 0.0169]])),\n", + " ('model.layers.12.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.12.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.13.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.13.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.13.mixer.in_proj.weight',\n", + " tensor([[-0.0064, -0.0200, 0.0384, ..., -0.0036, 0.0158, -0.0007],\n", + " [-0.0074, 0.0105, 0.0043, ..., 0.0097, 0.0259, -0.0012],\n", + " [ 0.0297, -0.0146, -0.0012, ..., 0.0273, 0.0309, 0.0087],\n", + " ...,\n", + " [ 0.0204, -0.0063, 0.0136, ..., -0.0092, 0.0196, 0.0057],\n", + " [ 0.0195, 0.0059, 0.0228, ..., 0.0093, -0.0183, -0.0003],\n", + " [-0.0131, -0.0447, -0.0262, ..., -0.0125, 0.0237, -0.0404]])),\n", + " ('model.layers.13.mixer.conv1d.weight',\n", + " tensor([[[ 7.7458e-03, 4.9829e-01, 2.1690e-01, -2.3587e-01]],\n", + " \n", + " [[ 3.7281e-01, -4.0991e-03, 2.4588e-01, -1.1600e-01]],\n", + " \n", + " [[-4.8238e-01, -2.8961e-01, -4.4331e-02, 1.0011e-01]],\n", + " \n", + " ...,\n", + " \n", + " [[-3.6304e-01, -1.4106e-01, -3.5434e-01, 1.4923e-01]],\n", + " \n", + " [[-2.3703e-01, 3.9285e-04, -2.1456e-02, -2.5568e-01]],\n", + " \n", + " [[ 1.5303e-02, -8.3474e-03, -3.2668e-01, -4.8096e-01]]])),\n", + " ('model.layers.13.mixer.conv1d.bias',\n", + " tensor([-0.2462, 0.1532, -0.2298, ..., -0.3016, 0.1210, -0.3777])),\n", + " ('model.layers.13.mixer.out_proj.weight',\n", + " tensor([[-0.0019, 0.0103, 0.0098, ..., -0.0050, 0.0180, -0.0117],\n", + " [-0.0153, 0.0134, -0.0102, ..., 0.0327, -0.0387, 0.0025],\n", + " [ 0.0102, -0.0038, 0.0224, ..., -0.0118, 0.0234, 0.0014],\n", + " ...,\n", + " [-0.0201, 0.0233, 0.0189, ..., 0.0010, 0.0313, 0.0130],\n", + " [ 0.0193, 0.0035, -0.0253, ..., 0.0084, -0.0208, 0.0372],\n", + " [ 0.0367, -0.0029, -0.0205, ..., -0.0055, -0.0209, 0.0082]])),\n", + " ('model.layers.13.mlp.gate_proj.weight',\n", + " tensor([[ 0.0148, -0.0052, 0.0371, ..., -0.0118, 0.0397, -0.0234],\n", + " [ 0.0237, -0.0323, 0.0219, ..., 0.0098, -0.0304, 0.0165],\n", + " [ 0.0168, -0.0289, 0.0038, ..., 0.0022, 0.0174, 0.0043],\n", + " ...,\n", + " [-0.0135, 0.0258, -0.0172, ..., 0.0251, -0.0071, -0.0384],\n", + " [ 0.0005, -0.0123, 0.0116, ..., 0.0041, -0.0108, -0.0068],\n", + " [ 0.0116, 0.0069, 0.0063, ..., 0.0045, -0.0145, 0.0185]])),\n", + " ('model.layers.13.mlp.up_proj.weight',\n", + " tensor([[-0.0002, -0.0120, 0.0069, ..., 0.0005, -0.0108, -0.0284],\n", + " [ 0.0215, 0.0045, 0.0167, ..., 0.0177, -0.0030, 0.0051],\n", + " [ 0.0265, 0.0169, 0.0047, ..., 0.0069, -0.0299, 0.0196],\n", + " ...,\n", + " [ 0.0127, -0.0063, 0.0242, ..., -0.0061, -0.0263, 0.0041],\n", + " [ 0.0142, -0.0515, -0.0221, ..., -0.0369, -0.0399, -0.0210],\n", + " [ 0.0123, 0.0133, -0.0269, ..., 0.0092, -0.0177, 0.0226]])),\n", + " ('model.layers.13.mlp.down_proj.weight',\n", + " tensor([[ 0.0048, 0.0360, -0.0037, ..., 0.0169, 0.0304, -0.0162],\n", + " [ 0.0271, -0.0121, 0.0108, ..., -0.0424, 0.0293, -0.0137],\n", + " [ 0.0225, -0.0061, -0.0096, ..., 0.0075, -0.0168, 0.0142],\n", + " ...,\n", + " [ 0.0039, -0.0152, -0.0156, ..., 0.0181, 0.0105, 0.0070],\n", + " [ 0.0311, 0.0205, 0.0259, ..., -0.0025, 0.0060, -0.0125],\n", + " [ 0.0004, -0.0114, 0.0022, ..., -0.0159, -0.0290, 0.0036]])),\n", + " ('model.layers.13.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.13.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.14.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.14.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.14.mixer.in_proj.weight',\n", + " tensor([[-0.0123, 0.0054, 0.0059, ..., 0.0285, -0.0292, -0.0184],\n", + " [-0.0146, -0.0175, 0.0155, ..., -0.0206, -0.0190, -0.0172],\n", + " [ 0.0050, -0.0235, -0.0159, ..., -0.0013, -0.0102, 0.0082],\n", + " ...,\n", + " [-0.0243, -0.0013, 0.0312, ..., -0.0141, -0.0156, 0.0279],\n", + " [ 0.0018, 0.0181, -0.0188, ..., 0.0593, -0.0155, 0.0156],\n", + " [ 0.0036, 0.0182, -0.0308, ..., 0.0306, -0.0035, 0.0037]])),\n", + " ('model.layers.14.mixer.conv1d.weight',\n", + " tensor([[[-0.4608, 0.4926, -0.2625, 0.3060]],\n", + " \n", + " [[-0.0932, 0.0153, 0.2298, -0.1735]],\n", + " \n", + " [[-0.1927, 0.1979, -0.1773, 0.3277]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0538, -0.2180, -0.4857, -0.1428]],\n", + " \n", + " [[-0.1736, 0.2405, 0.3148, -0.4481]],\n", + " \n", + " [[-0.4971, -0.1558, 0.2762, -0.1849]]])),\n", + " ('model.layers.14.mixer.conv1d.bias',\n", + " tensor([-0.2181, -0.2375, 0.0896, ..., 0.0744, 0.0857, 0.4347])),\n", + " ('model.layers.14.mixer.out_proj.weight',\n", + " tensor([[-3.8364e-04, 2.4458e-02, 5.8783e-03, ..., -1.3479e-02,\n", + " -2.4306e-02, 5.7698e-03],\n", + " [ 4.5843e-02, -3.9217e-03, -6.9897e-03, ..., 5.5401e-03,\n", + " -1.4523e-02, 1.2266e-02],\n", + " [-7.1069e-03, 5.5550e-03, 1.1359e-02, ..., 3.5839e-02,\n", + " 1.0787e-02, 8.4053e-03],\n", + " ...,\n", + " [ 3.3029e-03, 5.4333e-03, -9.3382e-03, ..., -1.7376e-02,\n", + " 1.5601e-02, -6.3227e-03],\n", + " [-6.9199e-03, -1.6950e-02, 1.5155e-03, ..., 1.2324e-02,\n", + " 1.2259e-02, 5.5500e-02],\n", + " [-1.6177e-02, -6.5257e-05, -9.3656e-03, ..., 1.0653e-02,\n", + " 1.8864e-02, -1.2508e-02]])),\n", + " ('model.layers.14.mlp.gate_proj.weight',\n", + " tensor([[ 0.0279, 0.0025, 0.0214, ..., -0.0137, -0.0042, 0.0172],\n", + " [-0.0240, -0.0150, 0.0170, ..., 0.0090, 0.0002, 0.0172],\n", + " [-0.0181, 0.0052, -0.0418, ..., 0.0106, 0.0052, -0.0264],\n", + " ...,\n", + " [-0.0295, 0.0323, 0.0387, ..., -0.0116, -0.0140, -0.0053],\n", + " [ 0.0411, 0.0189, 0.0236, ..., 0.0094, -0.0176, -0.0066],\n", + " [ 0.0004, 0.0291, 0.0402, ..., 0.0127, -0.0009, 0.0010]])),\n", + " ('model.layers.14.mlp.up_proj.weight',\n", + " tensor([[ 0.0198, -0.0115, -0.0045, ..., 0.0273, 0.0012, -0.0082],\n", + " [-0.0217, 0.0075, 0.0006, ..., 0.0047, -0.0416, -0.0011],\n", + " [ 0.0012, -0.0214, -0.0211, ..., 0.0030, -0.0176, -0.0215],\n", + " ...,\n", + " [ 0.0062, -0.0305, 0.0310, ..., 0.0044, -0.0379, 0.0155],\n", + " [-0.0062, 0.0451, 0.0167, ..., 0.0062, -0.0033, 0.0012],\n", + " [ 0.0293, -0.0186, 0.0295, ..., 0.0092, 0.0100, 0.0038]])),\n", + " ('model.layers.14.mlp.down_proj.weight',\n", + " tensor([[ 0.0019, 0.0114, -0.0202, ..., 0.0227, -0.0227, -0.0005],\n", + " [-0.0437, -0.0045, -0.0385, ..., -0.0083, -0.0135, 0.0172],\n", + " [-0.0032, -0.0024, 0.0137, ..., 0.0071, 0.0034, 0.0104],\n", + " ...,\n", + " [ 0.0210, -0.0237, -0.0166, ..., -0.0105, 0.0490, 0.0155],\n", + " [-0.0109, 0.0112, 0.0082, ..., -0.0342, -0.0133, -0.0086],\n", + " [ 0.0282, -0.0210, -0.0127, ..., -0.0047, -0.0126, 0.0103]])),\n", + " ('model.layers.14.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.14.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.15.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.15.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.15.mixer.in_proj.weight',\n", + " tensor([[-0.0098, -0.0201, -0.0033, ..., -0.0289, 0.0275, 0.0186],\n", + " [ 0.0048, 0.0075, -0.0033, ..., 0.0011, 0.0042, 0.0040],\n", + " [-0.0079, -0.0025, 0.0018, ..., -0.0051, -0.0231, -0.0022],\n", + " ...,\n", + " [ 0.0186, -0.0104, -0.0062, ..., 0.0086, -0.0007, -0.0653],\n", + " [-0.0212, 0.0034, 0.0019, ..., 0.0167, 0.0050, 0.0120],\n", + " [ 0.0066, 0.0381, -0.0225, ..., -0.0043, 0.0229, -0.0004]])),\n", + " ('model.layers.15.mixer.conv1d.weight',\n", + " tensor([[[ 0.2306, 0.2721, 0.3406, 0.4513]],\n", + " \n", + " [[ 0.0991, 0.4973, 0.0010, -0.1445]],\n", + " \n", + " [[ 0.2975, 0.4813, 0.2817, -0.0468]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0104, -0.1473, 0.1685, -0.4390]],\n", + " \n", + " [[ 0.3669, 0.3461, 0.0845, 0.3576]],\n", + " \n", + " [[-0.1177, 0.0524, 0.4329, 0.0687]]])),\n", + " ('model.layers.15.mixer.conv1d.bias',\n", + " tensor([-0.0356, 0.4173, 0.3287, ..., -0.0141, 0.1365, 0.2086])),\n", + " ('model.layers.15.mixer.out_proj.weight',\n", + " tensor([[-0.0137, -0.0239, -0.0133, ..., -0.0177, -0.0125, -0.0015],\n", + " [ 0.0168, 0.0120, 0.0034, ..., 0.0098, 0.0098, 0.0110],\n", + " [-0.0315, 0.0447, 0.0189, ..., 0.0305, 0.0131, -0.0230],\n", + " ...,\n", + " [-0.0480, 0.0170, 0.0025, ..., 0.0317, -0.0378, -0.0236],\n", + " [-0.0319, -0.0290, 0.0023, ..., -0.0093, 0.0354, 0.0126],\n", + " [-0.0107, 0.0100, -0.0101, ..., 0.0046, 0.0205, -0.0203]])),\n", + " ('model.layers.15.mlp.gate_proj.weight',\n", + " tensor([[ 0.0160, 0.0432, 0.0073, ..., -0.0003, -0.0170, 0.0236],\n", + " [ 0.0055, 0.0066, -0.0311, ..., 0.0049, -0.0130, 0.0040],\n", + " [-0.0147, -0.0184, 0.0281, ..., 0.0016, 0.0077, -0.0072],\n", + " ...,\n", + " [-0.0049, -0.0434, -0.0118, ..., 0.0137, -0.0225, -0.0058],\n", + " [ 0.0221, -0.0077, 0.0029, ..., 0.0087, -0.0361, -0.0100],\n", + " [ 0.0263, 0.0228, 0.0050, ..., -0.0557, 0.0037, 0.0196]])),\n", + " ('model.layers.15.mlp.up_proj.weight',\n", + " tensor([[ 0.0093, -0.0189, 0.0173, ..., 0.0276, 0.0075, -0.0215],\n", + " [-0.0147, 0.0241, 0.0109, ..., 0.0120, 0.0032, 0.0327],\n", + " [ 0.0036, 0.0127, 0.0116, ..., 0.0100, -0.0003, 0.0233],\n", + " ...,\n", + " [-0.0063, 0.0160, 0.0138, ..., -0.0078, -0.0098, 0.0150],\n", + " [ 0.0138, -0.0236, 0.0109, ..., -0.0156, -0.0143, 0.0273],\n", + " [ 0.0345, 0.0201, -0.0119, ..., -0.0182, 0.0053, 0.0105]])),\n", + " ('model.layers.15.mlp.down_proj.weight',\n", + " tensor([[-0.0114, 0.0138, -0.0110, ..., 0.0084, -0.0144, 0.0100],\n", + " [ 0.0016, -0.0069, 0.0172, ..., -0.0394, 0.0368, 0.0468],\n", + " [-0.0184, -0.0094, -0.0273, ..., -0.0195, 0.0148, 0.0142],\n", + " ...,\n", + " [ 0.0311, 0.0093, -0.0130, ..., -0.0023, 0.0395, -0.0375],\n", + " [ 0.0056, 0.0027, 0.0061, ..., 0.0058, 0.0225, -0.0153],\n", + " [-0.0031, -0.0107, 0.0020, ..., -0.0173, -0.0050, 0.0423]])),\n", + " ('model.layers.15.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.15.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.16.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.16.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.16.mixer.in_proj.weight',\n", + " tensor([[-0.0063, 0.0006, 0.0130, ..., 0.0186, 0.0408, 0.0126],\n", + " [-0.0015, -0.0029, 0.0268, ..., -0.0042, -0.0209, -0.0046],\n", + " [-0.0034, -0.0286, 0.0185, ..., -0.0125, 0.0050, 0.0033],\n", + " ...,\n", + " [ 0.0045, 0.0133, 0.0220, ..., 0.0165, 0.0287, 0.0371],\n", + " [ 0.0100, -0.0232, 0.0103, ..., -0.0083, -0.0105, -0.0187],\n", + " [-0.0412, -0.0035, 0.0028, ..., 0.0286, 0.0349, -0.0037]])),\n", + " ('model.layers.16.mixer.conv1d.weight',\n", + " tensor([[[-0.1874, 0.2517, 0.0537, 0.1258]],\n", + " \n", + " [[ 0.1465, 0.2013, 0.3547, 0.2689]],\n", + " \n", + " [[ 0.4834, 0.4906, 0.0844, -0.0541]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3004, 0.3313, 0.1688, 0.4381]],\n", + " \n", + " [[-0.0606, 0.3455, -0.0910, 0.1148]],\n", + " \n", + " [[-0.1421, -0.1254, -0.2353, -0.1675]]])),\n", + " ('model.layers.16.mixer.conv1d.bias',\n", + " tensor([ 0.2835, 0.2361, 0.1225, ..., -0.2119, -0.1929, 0.3877])),\n", + " ('model.layers.16.mixer.out_proj.weight',\n", + " tensor([[-0.0121, 0.0194, 0.0060, ..., -0.0029, -0.0147, -0.0085],\n", + " [-0.0216, -0.0012, 0.0287, ..., 0.0102, -0.0133, -0.0153],\n", + " [ 0.0136, -0.0296, 0.0417, ..., -0.0118, -0.0283, 0.0359],\n", + " ...,\n", + " [-0.0263, -0.0003, 0.0022, ..., 0.0135, -0.0519, -0.0254],\n", + " [ 0.0121, -0.0144, -0.0026, ..., 0.0096, 0.0130, 0.0095],\n", + " [-0.0147, -0.0217, 0.0099, ..., 0.0267, -0.0072, -0.0213]])),\n", + " ('model.layers.16.mlp.gate_proj.weight',\n", + " tensor([[ 0.0103, -0.0396, -0.0127, ..., 0.0020, -0.0055, 0.0291],\n", + " [ 0.0194, 0.0357, -0.0020, ..., -0.0112, 0.0448, -0.0224],\n", + " [-0.0390, 0.0142, -0.0224, ..., -0.0030, 0.0102, 0.0078],\n", + " ...,\n", + " [ 0.0165, -0.0251, 0.0196, ..., 0.0213, 0.0040, -0.0228],\n", + " [-0.0145, 0.0218, -0.0032, ..., -0.0240, -0.0079, 0.0256],\n", + " [ 0.0539, -0.0027, -0.0227, ..., -0.0184, -0.0109, 0.0236]])),\n", + " ('model.layers.16.mlp.up_proj.weight',\n", + " tensor([[ 7.1125e-03, -3.2583e-04, -2.6297e-02, ..., -4.9575e-03,\n", + " -1.2243e-02, -1.3005e-02],\n", + " [ 2.5637e-02, -1.1874e-02, 1.1376e-02, ..., -1.4700e-02,\n", + " -1.5193e-02, 2.6111e-03],\n", + " [-4.8919e-02, -4.9716e-04, 5.8527e-03, ..., 8.6775e-05,\n", + " 1.0694e-02, 3.7682e-03],\n", + " ...,\n", + " [ 8.8393e-03, -4.3317e-02, 2.8372e-02, ..., 2.2709e-02,\n", + " -4.8128e-03, 1.6899e-02],\n", + " [ 1.3257e-02, 2.1000e-02, 1.5035e-03, ..., 1.5603e-02,\n", + " -5.5857e-03, 4.0449e-03],\n", + " [-2.6754e-02, -1.6263e-02, 1.9013e-02, ..., -9.0918e-03,\n", + " -8.0242e-03, -1.0925e-02]])),\n", + " ('model.layers.16.mlp.down_proj.weight',\n", + " tensor([[ 0.0207, -0.0038, -0.0234, ..., 0.0299, -0.0329, -0.0117],\n", + " [-0.0316, 0.0032, 0.0131, ..., 0.0020, -0.0320, 0.0381],\n", + " [-0.0192, -0.0031, -0.0030, ..., -0.0224, 0.0037, 0.0085],\n", + " ...,\n", + " [ 0.0044, 0.0281, -0.0208, ..., 0.0179, -0.0085, -0.0010],\n", + " [-0.0076, -0.0008, 0.0483, ..., 0.0082, -0.0177, -0.0039],\n", + " [ 0.0224, 0.0019, 0.0181, ..., 0.0143, -0.0252, 0.0022]])),\n", + " ('model.layers.16.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.16.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.17.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.17.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.17.mixer.in_proj.weight',\n", + " tensor([[-0.0115, 0.0061, -0.0062, ..., -0.0132, -0.0047, 0.0274],\n", + " [ 0.0076, 0.0278, -0.0147, ..., 0.0439, -0.0093, -0.0154],\n", + " [-0.0383, -0.0264, -0.0053, ..., -0.0206, 0.0275, 0.0188],\n", + " ...,\n", + " [ 0.0096, 0.0228, 0.0351, ..., 0.0227, 0.0138, -0.0164],\n", + " [ 0.0321, -0.0293, -0.0054, ..., 0.0109, -0.0113, -0.0130],\n", + " [-0.0120, -0.0132, 0.0092, ..., -0.0338, 0.0308, -0.0135]])),\n", + " ('model.layers.17.mixer.conv1d.weight',\n", + " tensor([[[-0.4933, 0.4156, 0.2523, -0.0026]],\n", + " \n", + " [[-0.2572, 0.4916, 0.3642, -0.2145]],\n", + " \n", + " [[ 0.0261, 0.4852, -0.1448, 0.2288]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3698, -0.4122, -0.2264, -0.1378]],\n", + " \n", + " [[ 0.1447, 0.4556, -0.0466, 0.0389]],\n", + " \n", + " [[-0.3891, 0.4149, 0.1454, -0.4282]]])),\n", + " ('model.layers.17.mixer.conv1d.bias',\n", + " tensor([-0.3919, -0.4015, 0.2591, ..., -0.3368, 0.2285, 0.1701])),\n", + " ('model.layers.17.mixer.out_proj.weight',\n", + " tensor([[-0.0127, -0.0155, 0.0193, ..., 0.0204, 0.0025, 0.0159],\n", + " [ 0.0192, 0.0194, -0.0169, ..., -0.0062, 0.0262, 0.0070],\n", + " [ 0.0397, 0.0009, 0.0189, ..., -0.0082, 0.0352, -0.0150],\n", + " ...,\n", + " [-0.0339, -0.0142, -0.0151, ..., 0.0229, 0.0032, 0.0038],\n", + " [ 0.0235, 0.0319, -0.0137, ..., -0.0121, 0.0112, 0.0162],\n", + " [ 0.0060, 0.0102, -0.0016, ..., 0.0118, 0.0158, -0.0140]])),\n", + " ('model.layers.17.mlp.gate_proj.weight',\n", + " tensor([[ 0.0285, -0.0090, -0.0095, ..., 0.0315, -0.0065, 0.0189],\n", + " [ 0.0040, -0.0358, -0.0039, ..., -0.0074, -0.0285, -0.0223],\n", + " [ 0.0202, 0.0021, -0.0104, ..., -0.0083, 0.0300, -0.0267],\n", + " ...,\n", + " [ 0.0093, -0.0008, -0.0372, ..., 0.0422, 0.0309, 0.0095],\n", + " [ 0.0027, 0.0252, 0.0378, ..., -0.0238, 0.0234, -0.0062],\n", + " [-0.0061, -0.0022, -0.0033, ..., 0.0157, -0.0296, 0.0034]])),\n", + " ('model.layers.17.mlp.up_proj.weight',\n", + " tensor([[ 0.0061, -0.0135, 0.0029, ..., 0.0328, 0.0008, -0.0072],\n", + " [ 0.0145, -0.0226, -0.0095, ..., 0.0114, 0.0224, -0.0160],\n", + " [ 0.0097, -0.0024, -0.0179, ..., 0.0073, -0.0061, -0.0195],\n", + " ...,\n", + " [ 0.0308, -0.0014, 0.0104, ..., 0.0047, 0.0026, 0.0243],\n", + " [-0.0364, 0.0350, 0.0031, ..., -0.0072, 0.0267, 0.0017],\n", + " [ 0.0227, -0.0146, 0.0146, ..., -0.0434, -0.0159, 0.0230]])),\n", + " ('model.layers.17.mlp.down_proj.weight',\n", + " tensor([[-0.0216, 0.0211, 0.0136, ..., -0.0004, 0.0051, 0.0415],\n", + " [-0.0061, -0.0123, 0.0156, ..., -0.0005, -0.0183, -0.0137],\n", + " [-0.0146, -0.0274, -0.0439, ..., -0.0033, -0.0030, -0.0074],\n", + " ...,\n", + " [-0.0108, -0.0005, -0.0094, ..., -0.0243, 0.0065, -0.0005],\n", + " [-0.0126, 0.0124, -0.0006, ..., -0.0282, -0.0110, 0.0128],\n", + " [-0.0162, -0.0102, 0.0025, ..., -0.0084, 0.0066, -0.0074]])),\n", + " ('model.layers.17.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.17.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.18.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.18.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.18.mixer.in_proj.weight',\n", + " tensor([[-9.4961e-03, -1.2349e-04, -7.1455e-03, ..., 1.9508e-02,\n", + " -6.8715e-03, -1.3565e-02],\n", + " [-2.9701e-03, 3.1580e-03, 1.8849e-02, ..., 7.6566e-03,\n", + " -1.0968e-02, -8.0445e-03],\n", + " [-1.5402e-02, -6.7267e-03, 9.6119e-03, ..., 1.9799e-02,\n", + " 2.0198e-03, -1.7366e-03],\n", + " ...,\n", + " [ 8.2379e-03, 5.1668e-03, 3.8116e-02, ..., -3.8710e-03,\n", + " 1.4452e-02, -2.5152e-02],\n", + " [ 1.1949e-02, -1.2245e-03, 1.0568e-02, ..., -3.1690e-02,\n", + " 3.8135e-05, 1.7263e-02],\n", + " [ 1.6173e-04, 5.6721e-04, 2.1043e-02, ..., -3.6167e-02,\n", + " -1.1129e-02, -9.6768e-03]])),\n", + " ('model.layers.18.mixer.conv1d.weight',\n", + " tensor([[[ 0.2776, 0.2169, -0.2840, 0.1736]],\n", + " \n", + " [[-0.0598, -0.2654, 0.2423, -0.0874]],\n", + " \n", + " [[-0.3612, -0.3049, -0.3197, -0.2763]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.1389, 0.2034, -0.1739, 0.1634]],\n", + " \n", + " [[-0.2836, -0.0471, 0.1284, -0.0099]],\n", + " \n", + " [[ 0.2952, -0.2676, -0.3961, 0.2656]]])),\n", + " ('model.layers.18.mixer.conv1d.bias',\n", + " tensor([ 0.1804, 0.0336, 0.4006, ..., 0.2943, -0.1079, 0.0963])),\n", + " ('model.layers.18.mixer.out_proj.weight',\n", + " tensor([[ 0.0109, -0.0181, 0.0148, ..., -0.0105, -0.0011, -0.0052],\n", + " [ 0.0507, 0.0100, -0.0273, ..., -0.0069, 0.0054, 0.0129],\n", + " [ 0.0014, 0.0423, -0.0193, ..., -0.0023, -0.0293, 0.0004],\n", + " ...,\n", + " [ 0.0420, -0.0401, 0.0205, ..., 0.0135, -0.0089, -0.0023],\n", + " [ 0.0242, 0.0273, 0.0139, ..., -0.0402, 0.0061, 0.0119],\n", + " [-0.0145, 0.0102, 0.0245, ..., 0.0205, -0.0251, 0.0006]])),\n", + " ('model.layers.18.mlp.gate_proj.weight',\n", + " tensor([[ 0.0241, -0.0086, 0.0136, ..., -0.0219, -0.0064, -0.0142],\n", + " [-0.0067, 0.0252, 0.0246, ..., -0.0205, -0.0273, 0.0137],\n", + " [-0.0030, 0.0055, -0.0063, ..., 0.0107, 0.0083, -0.0037],\n", + " ...,\n", + " [-0.0154, 0.0101, 0.0221, ..., 0.0025, -0.0109, 0.0133],\n", + " [-0.0175, 0.0105, -0.0246, ..., 0.0244, 0.0023, 0.0080],\n", + " [-0.0060, 0.0183, 0.0297, ..., 0.0420, -0.0006, -0.0119]])),\n", + " ('model.layers.18.mlp.up_proj.weight',\n", + " tensor([[ 0.0066, -0.0009, -0.0070, ..., -0.0064, 0.0002, 0.0196],\n", + " [-0.0173, -0.0362, -0.0011, ..., 0.0158, -0.0198, -0.0046],\n", + " [ 0.0133, -0.0090, -0.0092, ..., 0.0039, -0.0052, -0.0101],\n", + " ...,\n", + " [ 0.0077, -0.0063, 0.0010, ..., 0.0091, 0.0218, 0.0132],\n", + " [ 0.0005, -0.0046, 0.0207, ..., 0.0112, 0.0183, -0.0020],\n", + " [ 0.0238, -0.0022, 0.0364, ..., -0.0042, 0.0237, 0.0183]])),\n", + " ('model.layers.18.mlp.down_proj.weight',\n", + " tensor([[ 0.0305, 0.0178, -0.0264, ..., -0.0158, 0.0135, 0.0132],\n", + " [ 0.0248, -0.0061, 0.0144, ..., -0.0165, 0.0098, 0.0410],\n", + " [-0.0156, -0.0039, 0.0112, ..., -0.0431, -0.0084, -0.0197],\n", + " ...,\n", + " [ 0.0071, 0.0236, -0.0038, ..., 0.0035, -0.0236, 0.0106],\n", + " [-0.0369, -0.0029, -0.0182, ..., -0.0008, -0.0417, 0.0064],\n", + " [-0.0273, 0.0207, 0.0130, ..., 0.0372, 0.0163, 0.0273]])),\n", + " ('model.layers.18.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.18.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.19.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.19.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.19.mixer.in_proj.weight',\n", + " tensor([[-0.0079, 0.0147, -0.0337, ..., -0.0201, -0.0254, 0.0035],\n", + " [ 0.0139, 0.0054, -0.0093, ..., -0.0208, -0.0289, -0.0087],\n", + " [ 0.0004, -0.0034, 0.0090, ..., -0.0109, -0.0093, 0.0102],\n", + " ...,\n", + " [ 0.0128, 0.0015, -0.0101, ..., -0.0482, -0.0217, 0.0144],\n", + " [-0.0100, -0.0079, 0.0286, ..., -0.0025, -0.0210, 0.0164],\n", + " [-0.0264, 0.0015, 0.0031, ..., 0.0027, 0.0131, -0.0384]])),\n", + " ('model.layers.19.mixer.conv1d.weight',\n", + " tensor([[[ 0.4729, 0.3708, -0.4394, -0.3549]],\n", + " \n", + " [[ 0.2230, -0.3271, 0.3017, -0.2552]],\n", + " \n", + " [[-0.0417, 0.1893, 0.4552, -0.0644]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.2565, 0.0407, 0.3521, 0.4116]],\n", + " \n", + " [[ 0.0795, -0.0374, 0.1034, 0.4254]],\n", + " \n", + " [[ 0.3333, 0.2431, 0.3459, -0.2676]]])),\n", + " ('model.layers.19.mixer.conv1d.bias',\n", + " tensor([-0.2287, -0.4446, -0.2300, ..., -0.2317, -0.3395, 0.4310])),\n", + " ('model.layers.19.mixer.out_proj.weight',\n", + " tensor([[-0.0456, -0.0167, -0.0117, ..., -0.0068, -0.0150, 0.0125],\n", + " [ 0.0194, 0.0172, -0.0232, ..., -0.0202, -0.0066, 0.0083],\n", + " [ 0.0320, -0.0065, 0.0274, ..., 0.0200, 0.0090, 0.0105],\n", + " ...,\n", + " [ 0.0315, 0.0415, 0.0128, ..., -0.0143, -0.0338, -0.0231],\n", + " [ 0.0227, -0.0177, -0.0034, ..., 0.0174, 0.0006, 0.0212],\n", + " [ 0.0358, 0.0084, 0.0075, ..., 0.0091, 0.0062, 0.0114]])),\n", + " ('model.layers.19.mlp.gate_proj.weight',\n", + " tensor([[-0.0010, 0.0156, 0.0042, ..., -0.0181, 0.0113, 0.0089],\n", + " [-0.0182, 0.0068, -0.0043, ..., -0.0323, -0.0019, -0.0045],\n", + " [ 0.0168, -0.0093, -0.0162, ..., -0.0074, 0.0166, -0.0334],\n", + " ...,\n", + " [ 0.0038, -0.0211, -0.0054, ..., -0.0229, 0.0193, -0.0210],\n", + " [ 0.0153, -0.0372, 0.0119, ..., 0.0043, -0.0097, -0.0025],\n", + " [ 0.0037, 0.0208, -0.0135, ..., 0.0052, -0.0125, -0.0282]])),\n", + " ('model.layers.19.mlp.up_proj.weight',\n", + " tensor([[-0.0026, 0.0360, 0.0161, ..., 0.0199, -0.0283, -0.0026],\n", + " [ 0.0185, 0.0122, -0.0299, ..., 0.0125, 0.0063, 0.0387],\n", + " [-0.0085, -0.0010, -0.0054, ..., -0.0088, -0.0034, -0.0179],\n", + " ...,\n", + " [-0.0179, 0.0211, -0.0003, ..., -0.0071, -0.0145, 0.0235],\n", + " [-0.0002, 0.0060, -0.0172, ..., -0.0086, 0.0175, -0.0232],\n", + " [-0.0081, -0.0280, -0.0152, ..., -0.0221, 0.0047, -0.0077]])),\n", + " ('model.layers.19.mlp.down_proj.weight',\n", + " tensor([[ 0.0038, -0.0027, -0.0122, ..., 0.0090, 0.0044, 0.0128],\n", + " [ 0.0054, 0.0075, 0.0116, ..., 0.0232, 0.0130, 0.0298],\n", + " [-0.0498, -0.0208, -0.0127, ..., 0.0166, -0.0221, 0.0038],\n", + " ...,\n", + " [ 0.0101, 0.0051, 0.0209, ..., 0.0137, -0.0225, 0.0142],\n", + " [-0.0433, -0.0217, -0.0167, ..., -0.0179, -0.0191, -0.0021],\n", + " [-0.0020, 0.0084, -0.0114, ..., 0.0324, 0.0216, -0.0062]])),\n", + " ('model.layers.19.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.19.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.20.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.20.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.20.mixer.in_proj.weight',\n", + " tensor([[ 3.3776e-02, 3.6619e-02, 6.8532e-03, ..., 5.7664e-02,\n", + " -2.3083e-02, -6.2962e-02],\n", + " [-2.9787e-03, -2.5050e-03, -3.4841e-03, ..., 5.4946e-03,\n", + " 9.0683e-03, 2.1583e-04],\n", + " [ 7.4430e-03, -1.0495e-02, 3.5169e-02, ..., -5.1808e-02,\n", + " 3.2650e-03, -3.1967e-02],\n", + " ...,\n", + " [-5.8685e-02, 4.8452e-02, -1.2612e-02, ..., 1.2174e-02,\n", + " 1.0566e-02, -4.9561e-03],\n", + " [ 3.1722e-03, -2.9390e-03, 1.4502e-05, ..., -2.3297e-02,\n", + " -7.5403e-03, -1.3599e-02],\n", + " [ 1.4845e-02, -4.3150e-02, -1.0338e-02, ..., -1.1149e-02,\n", + " -3.3432e-02, 3.8337e-03]])),\n", + " ('model.layers.20.mixer.conv1d.weight',\n", + " tensor([[[-0.3842, 0.2397, 0.4873, -0.3091]],\n", + " \n", + " [[-0.1886, 0.0751, 0.2026, -0.2674]],\n", + " \n", + " [[-0.0594, 0.3119, -0.2404, 0.1652]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0028, 0.1315, 0.0515, 0.3189]],\n", + " \n", + " [[-0.1461, -0.0457, -0.0536, -0.2306]],\n", + " \n", + " [[-0.3025, -0.3339, 0.3007, -0.3007]]])),\n", + " ('model.layers.20.mixer.conv1d.bias',\n", + " tensor([-0.4901, -0.3784, -0.0173, ..., -0.3946, -0.0728, 0.2187])),\n", + " ('model.layers.20.mixer.out_proj.weight',\n", + " tensor([[ 0.0095, -0.0037, -0.0218, ..., 0.0080, 0.0062, 0.0246],\n", + " [-0.0197, 0.0037, 0.0076, ..., 0.0171, 0.0238, -0.0195],\n", + " [ 0.0364, -0.0165, 0.0224, ..., -0.0099, 0.0007, 0.0340],\n", + " ...,\n", + " [ 0.0235, -0.0072, -0.0319, ..., 0.0045, -0.0196, 0.0011],\n", + " [-0.0369, 0.0083, 0.0021, ..., -0.0357, -0.0039, -0.0150],\n", + " [-0.0174, -0.0211, 0.0111, ..., 0.0251, 0.0040, -0.0308]])),\n", + " ('model.layers.20.mlp.gate_proj.weight',\n", + " tensor([[ 0.0161, -0.0019, -0.0473, ..., 0.0019, 0.0075, -0.0038],\n", + " [-0.0321, -0.0020, -0.0100, ..., 0.0035, 0.0291, -0.0058],\n", + " [-0.0158, 0.0020, 0.0353, ..., 0.0125, 0.0228, -0.0392],\n", + " ...,\n", + " [ 0.0113, 0.0171, 0.0235, ..., 0.0043, 0.0378, 0.0391],\n", + " [ 0.0090, 0.0067, 0.0031, ..., 0.0291, -0.0052, -0.0216],\n", + " [ 0.0042, -0.0112, -0.0161, ..., -0.0063, -0.0156, 0.0211]])),\n", + " ('model.layers.20.mlp.up_proj.weight',\n", + " tensor([[ 0.0104, -0.0302, -0.0220, ..., -0.0072, -0.0083, -0.0066],\n", + " [ 0.0409, -0.0116, -0.0125, ..., 0.0182, 0.0267, 0.0099],\n", + " [-0.0055, 0.0104, 0.0027, ..., -0.0075, -0.0368, -0.0092],\n", + " ...,\n", + " [-0.0089, 0.0243, -0.0028, ..., -0.0136, -0.0176, -0.0054],\n", + " [ 0.0088, 0.0365, -0.0354, ..., 0.0035, 0.0280, 0.0155],\n", + " [-0.0472, 0.0088, 0.0102, ..., -0.0120, 0.0004, -0.0011]])),\n", + " ('model.layers.20.mlp.down_proj.weight',\n", + " tensor([[-0.0089, -0.0112, -0.0007, ..., 0.0360, -0.0077, 0.0261],\n", + " [ 0.0080, -0.0128, -0.0445, ..., 0.0095, -0.0298, 0.0176],\n", + " [ 0.0357, -0.0262, 0.0028, ..., 0.0162, 0.0089, 0.0050],\n", + " ...,\n", + " [-0.0129, 0.0216, 0.0125, ..., -0.0062, -0.0344, -0.0218],\n", + " [ 0.0006, -0.0143, -0.0099, ..., -0.0359, 0.0268, 0.0259],\n", + " [ 0.0222, -0.0154, 0.0013, ..., 0.0108, -0.0077, 0.0186]])),\n", + " ('model.layers.20.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.20.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.21.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.21.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.21.mixer.in_proj.weight',\n", + " tensor([[-0.0300, 0.0058, -0.0107, ..., -0.0318, 0.0350, 0.0350],\n", + " [ 0.0186, 0.0238, -0.0268, ..., 0.0142, -0.0277, -0.0095],\n", + " [-0.0061, 0.0083, 0.0072, ..., 0.0161, 0.0027, -0.0051],\n", + " ...,\n", + " [-0.0358, 0.0330, 0.0151, ..., -0.0376, 0.0057, 0.0174],\n", + " [-0.0021, 0.0068, 0.0151, ..., 0.0077, -0.0353, 0.0095],\n", + " [-0.0113, -0.0043, 0.0064, ..., -0.0063, -0.0232, -0.0058]])),\n", + " ('model.layers.21.mixer.conv1d.weight',\n", + " tensor([[[ 0.0354, 0.0496, -0.0106, 0.0084]],\n", + " \n", + " [[ 0.2553, 0.3217, -0.0078, -0.2333]],\n", + " \n", + " [[-0.1390, 0.0323, 0.4914, -0.2047]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.2243, 0.2984, 0.0188, 0.1830]],\n", + " \n", + " [[ 0.0756, 0.1443, -0.4898, -0.2082]],\n", + " \n", + " [[-0.3685, -0.1311, -0.4037, -0.3276]]])),\n", + " ('model.layers.21.mixer.conv1d.bias',\n", + " tensor([-0.2444, -0.1852, 0.2215, ..., 0.4515, 0.2532, -0.2388])),\n", + " ('model.layers.21.mixer.out_proj.weight',\n", + " tensor([[ 0.0232, 0.0328, 0.0026, ..., -0.0575, 0.0157, -0.0072],\n", + " [-0.0226, 0.0058, -0.0346, ..., 0.0092, 0.0078, 0.0108],\n", + " [ 0.0045, 0.0247, 0.0150, ..., -0.0085, 0.0268, 0.0253],\n", + " ...,\n", + " [ 0.0268, 0.0092, 0.0141, ..., 0.0062, 0.0177, -0.0405],\n", + " [ 0.0163, -0.0269, -0.0177, ..., 0.0029, -0.0080, -0.0036],\n", + " [ 0.0064, 0.0126, 0.0126, ..., -0.0400, -0.0015, -0.0088]])),\n", + " ('model.layers.21.mlp.gate_proj.weight',\n", + " tensor([[-3.7050e-02, 4.5834e-02, 1.9280e-02, ..., 1.6761e-02,\n", + " -5.8295e-03, -1.4284e-02],\n", + " [ 3.0156e-02, 3.2832e-02, 1.1083e-02, ..., -5.8261e-03,\n", + " -3.9076e-02, 5.3379e-03],\n", + " [ 1.3118e-03, 3.1510e-02, 1.5472e-02, ..., 1.8213e-02,\n", + " -2.5180e-02, 6.1512e-04],\n", + " ...,\n", + " [ 4.2010e-02, 1.0362e-02, 7.1759e-03, ..., 1.8667e-03,\n", + " -7.2165e-03, 1.6297e-02],\n", + " [ 1.8175e-02, 1.2840e-02, 3.2857e-03, ..., 1.8495e-02,\n", + " -7.7709e-03, 4.3964e-04],\n", + " [-9.2628e-05, 2.1701e-02, 2.1256e-02, ..., 2.5241e-02,\n", + " 5.0683e-02, -2.5481e-02]])),\n", + " ('model.layers.21.mlp.up_proj.weight',\n", + " tensor([[ 0.0228, 0.0082, -0.0083, ..., 0.0288, 0.0211, 0.0085],\n", + " [-0.0155, 0.0179, 0.0111, ..., -0.0218, -0.0162, -0.0052],\n", + " [ 0.0016, 0.0009, 0.0230, ..., -0.0017, 0.0131, 0.0255],\n", + " ...,\n", + " [-0.0098, -0.0098, -0.0188, ..., 0.0063, 0.0082, 0.0052],\n", + " [-0.0028, 0.0249, -0.0153, ..., -0.0208, 0.0130, -0.0093],\n", + " [ 0.0105, -0.0072, -0.0379, ..., 0.0035, 0.0182, 0.0307]])),\n", + " ('model.layers.21.mlp.down_proj.weight',\n", + " tensor([[-0.0445, -0.0116, 0.0058, ..., 0.0081, -0.0099, 0.0094],\n", + " [ 0.0106, -0.0387, 0.0051, ..., 0.0017, 0.0075, 0.0136],\n", + " [ 0.0022, 0.0058, -0.0268, ..., -0.0088, -0.0149, 0.0125],\n", + " ...,\n", + " [-0.0015, -0.0156, -0.0225, ..., 0.0100, -0.0118, -0.0019],\n", + " [-0.0161, -0.0225, -0.0060, ..., 0.0073, -0.0072, 0.0205],\n", + " [-0.0112, 0.0046, -0.0089, ..., -0.0014, -0.0221, 0.0124]])),\n", + " ('model.layers.21.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.21.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.22.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.22.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.22.mixer.in_proj.weight',\n", + " tensor([[-1.1591e-02, -6.0118e-03, -2.2227e-03, ..., -7.1433e-03,\n", + " -1.5757e-02, -1.5315e-03],\n", + " [-7.6057e-03, -4.2199e-02, 1.4478e-02, ..., 5.6496e-02,\n", + " 8.9105e-05, -3.8658e-03],\n", + " [-1.0330e-03, 2.3586e-02, 2.1835e-02, ..., -1.4911e-03,\n", + " -1.6604e-02, -4.5245e-03],\n", + " ...,\n", + " [-6.7261e-03, -6.9826e-03, -9.3003e-03, ..., -4.3939e-02,\n", + " 2.3792e-02, -5.5165e-03],\n", + " [-1.1798e-02, -3.4709e-02, -4.1277e-03, ..., -5.1867e-03,\n", + " 5.2496e-03, -6.0055e-03],\n", + " [ 7.3402e-04, -1.9525e-02, -5.8966e-03, ..., -1.5972e-02,\n", + " -1.5446e-02, -2.7164e-02]])),\n", + " ('model.layers.22.mixer.conv1d.weight',\n", + " tensor([[[-0.3791, 0.0616, 0.0369, 0.1365]],\n", + " \n", + " [[-0.4674, -0.4557, 0.3894, -0.4765]],\n", + " \n", + " [[ 0.3333, 0.2265, 0.1385, -0.1352]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4363, -0.3526, -0.3982, -0.1049]],\n", + " \n", + " [[ 0.4798, -0.3912, 0.4059, -0.1379]],\n", + " \n", + " [[-0.4427, 0.4661, -0.1990, 0.1668]]])),\n", + " ('model.layers.22.mixer.conv1d.bias',\n", + " tensor([-0.1823, -0.4117, 0.4443, ..., -0.0024, 0.2144, -0.4922])),\n", + " ('model.layers.22.mixer.out_proj.weight',\n", + " tensor([[ 0.0138, -0.0169, -0.0349, ..., -0.0045, 0.0023, -0.0389],\n", + " [ 0.0250, 0.0040, -0.0259, ..., 0.0458, 0.0311, -0.0054],\n", + " [-0.0056, 0.0012, -0.0027, ..., 0.0095, -0.0089, -0.0106],\n", + " ...,\n", + " [ 0.0228, -0.0258, 0.0040, ..., 0.0276, -0.0121, -0.0239],\n", + " [ 0.0082, 0.0041, 0.0145, ..., 0.0079, -0.0076, 0.0177],\n", + " [ 0.0310, -0.0092, -0.0174, ..., 0.0179, 0.0231, -0.0035]])),\n", + " ('model.layers.22.mlp.gate_proj.weight',\n", + " tensor([[ 0.0090, -0.0178, -0.0120, ..., -0.0073, -0.0149, 0.0187],\n", + " [ 0.0263, -0.0093, -0.0074, ..., -0.0472, 0.0049, 0.0288],\n", + " [ 0.0159, -0.0083, 0.0291, ..., 0.0089, -0.0076, -0.0167],\n", + " ...,\n", + " [-0.0008, 0.0206, 0.0199, ..., -0.0134, -0.0366, -0.0202],\n", + " [-0.0069, -0.0275, 0.0054, ..., 0.0093, 0.0108, 0.0094],\n", + " [ 0.0198, 0.0033, -0.0118, ..., -0.0262, 0.0241, 0.0084]])),\n", + " ('model.layers.22.mlp.up_proj.weight',\n", + " tensor([[-0.0277, 0.0038, 0.0006, ..., -0.0222, -0.0313, -0.0133],\n", + " [ 0.0132, -0.0373, 0.0109, ..., 0.0359, -0.0116, 0.0099],\n", + " [ 0.0139, -0.0185, 0.0247, ..., 0.0178, 0.0192, 0.0049],\n", + " ...,\n", + " [ 0.0362, 0.0072, -0.0236, ..., -0.0238, 0.0319, -0.0210],\n", + " [ 0.0013, -0.0047, -0.0060, ..., 0.0106, -0.0074, -0.0185],\n", + " [-0.0228, 0.0176, -0.0047, ..., -0.0034, -0.0174, -0.0264]])),\n", + " ('model.layers.22.mlp.down_proj.weight',\n", + " tensor([[ 0.0149, 0.0122, -0.0037, ..., 0.0044, 0.0171, -0.0186],\n", + " [-0.0037, -0.0002, 0.0066, ..., 0.0263, -0.0025, -0.0012],\n", + " [-0.0075, 0.0209, 0.0045, ..., 0.0082, -0.0160, 0.0079],\n", + " ...,\n", + " [ 0.0001, 0.0507, -0.0078, ..., 0.0001, -0.0119, 0.0286],\n", + " [-0.0198, -0.0122, 0.0047, ..., -0.0052, 0.0130, -0.0007],\n", + " [ 0.0241, -0.0002, -0.0147, ..., 0.0219, -0.0020, -0.0071]])),\n", + " ('model.layers.22.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.22.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.23.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.23.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.23.mixer.in_proj.weight',\n", + " tensor([[-0.0017, 0.0027, -0.0150, ..., 0.0392, -0.0079, -0.0367],\n", + " [ 0.0183, 0.0261, -0.0262, ..., -0.0157, 0.0197, 0.0135],\n", + " [-0.0030, 0.0170, 0.0032, ..., 0.0059, 0.0299, 0.0158],\n", + " ...,\n", + " [-0.0149, 0.0218, 0.0072, ..., -0.0302, 0.0035, 0.0153],\n", + " [-0.0135, 0.0425, 0.0331, ..., -0.0119, -0.0364, 0.0365],\n", + " [-0.0215, -0.0242, 0.0271, ..., 0.0500, 0.0293, 0.0100]])),\n", + " ('model.layers.23.mixer.conv1d.weight',\n", + " tensor([[[ 0.2464, 0.3726, 0.2719, 0.3580]],\n", + " \n", + " [[-0.0520, 0.0010, 0.1396, -0.4634]],\n", + " \n", + " [[ 0.1383, 0.4039, -0.3622, 0.1499]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.4094, 0.0541, 0.2240, -0.1545]],\n", + " \n", + " [[-0.4393, 0.1323, 0.1705, -0.1722]],\n", + " \n", + " [[ 0.2166, -0.4335, -0.4088, -0.1159]]])),\n", + " ('model.layers.23.mixer.conv1d.bias',\n", + " tensor([ 0.3175, -0.0325, -0.4654, ..., 0.3869, -0.2534, 0.1588])),\n", + " ('model.layers.23.mixer.out_proj.weight',\n", + " tensor([[-0.0354, -0.0041, 0.0196, ..., -0.0218, -0.0222, 0.0126],\n", + " [-0.0155, -0.0067, -0.0007, ..., 0.0112, -0.0036, -0.0054],\n", + " [ 0.0141, 0.0040, -0.0218, ..., -0.0178, -0.0031, 0.0162],\n", + " ...,\n", + " [ 0.0264, 0.0063, 0.0088, ..., -0.0310, -0.0116, 0.0239],\n", + " [-0.0031, 0.0056, -0.0243, ..., -0.0350, 0.0004, 0.0004],\n", + " [ 0.0229, -0.0201, 0.0124, ..., 0.0313, -0.0412, -0.0033]])),\n", + " ('model.layers.23.mlp.gate_proj.weight',\n", + " tensor([[ 0.0026, -0.0155, 0.0595, ..., 0.0204, 0.0172, 0.0378],\n", + " [-0.0011, -0.0253, 0.0039, ..., 0.0330, -0.0487, -0.0195],\n", + " [ 0.0174, 0.0039, -0.0029, ..., -0.0026, 0.0104, 0.0108],\n", + " ...,\n", + " [-0.0159, 0.0008, 0.0173, ..., -0.0020, 0.0085, -0.0043],\n", + " [ 0.0101, 0.0221, -0.0034, ..., -0.0268, 0.0056, 0.0137],\n", + " [-0.0031, -0.0151, 0.0073, ..., -0.0083, -0.0064, 0.0109]])),\n", + " ('model.layers.23.mlp.up_proj.weight',\n", + " tensor([[ 0.0173, -0.0132, -0.0027, ..., 0.0391, 0.0268, -0.0185],\n", + " [ 0.0221, -0.0110, -0.0108, ..., -0.0302, 0.0170, 0.0139],\n", + " [-0.0047, -0.0373, 0.0056, ..., -0.0389, -0.0175, -0.0410],\n", + " ...,\n", + " [ 0.0003, 0.0153, 0.0160, ..., 0.0002, -0.0136, 0.0417],\n", + " [-0.0059, -0.0150, -0.0111, ..., 0.0163, 0.0171, 0.0267],\n", + " [-0.0123, -0.0032, 0.0193, ..., -0.0051, -0.0051, -0.0089]])),\n", + " ('model.layers.23.mlp.down_proj.weight',\n", + " tensor([[-0.0092, -0.0148, -0.0345, ..., -0.0240, 0.0425, -0.0099],\n", + " [ 0.0458, 0.0156, -0.0067, ..., -0.0283, 0.0401, 0.0074],\n", + " [ 0.0180, -0.0008, 0.0049, ..., -0.0085, -0.0157, 0.0044],\n", + " ...,\n", + " [-0.0207, 0.0074, -0.0176, ..., 0.0038, -0.0238, -0.0026],\n", + " [-0.0201, 0.0078, 0.0243, ..., -0.0031, 0.0080, -0.0176],\n", + " [-0.0034, 0.0191, 0.0391, ..., -0.0114, 0.0133, -0.0261]])),\n", + " ('model.layers.23.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.23.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.24.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.24.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.24.mixer.in_proj.weight',\n", + " tensor([[-0.0184, -0.0299, 0.0165, ..., 0.0035, 0.0417, -0.0170],\n", + " [-0.0346, -0.0226, 0.0064, ..., 0.0072, 0.0457, -0.0148],\n", + " [ 0.0032, -0.0245, -0.0474, ..., -0.0054, -0.0044, 0.0278],\n", + " ...,\n", + " [ 0.0139, 0.0133, -0.0185, ..., 0.0188, 0.0119, -0.0205],\n", + " [ 0.0235, 0.0161, -0.0095, ..., 0.0013, -0.0382, 0.0213],\n", + " [ 0.0031, -0.0394, 0.0275, ..., -0.0068, 0.0024, 0.0179]])),\n", + " ('model.layers.24.mixer.conv1d.weight',\n", + " tensor([[[-0.1857, -0.4692, 0.4791, 0.3706]],\n", + " \n", + " [[ 0.1749, 0.4182, -0.2338, 0.0838]],\n", + " \n", + " [[-0.1204, -0.2985, -0.0470, 0.4674]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.1485, 0.3118, -0.4916, -0.1610]],\n", + " \n", + " [[ 0.0684, -0.2980, 0.4517, -0.3662]],\n", + " \n", + " [[ 0.2353, -0.2156, -0.3332, -0.0665]]])),\n", + " ('model.layers.24.mixer.conv1d.bias',\n", + " tensor([-0.4464, -0.3485, -0.3916, ..., 0.2513, -0.0601, 0.1546])),\n", + " ('model.layers.24.mixer.out_proj.weight',\n", + " tensor([[-0.0023, 0.0087, -0.0280, ..., 0.0338, -0.0095, -0.0237],\n", + " [-0.0086, -0.0084, 0.0180, ..., 0.0350, 0.0463, -0.0270],\n", + " [-0.0093, -0.0009, 0.0236, ..., 0.0158, 0.0246, 0.0068],\n", + " ...,\n", + " [ 0.0526, 0.0009, 0.0039, ..., -0.0206, -0.0538, 0.0287],\n", + " [ 0.0054, -0.0053, -0.0108, ..., 0.0167, -0.0997, 0.0036],\n", + " [ 0.0009, -0.0297, -0.0424, ..., -0.0096, -0.0235, 0.0117]])),\n", + " ('model.layers.24.mlp.gate_proj.weight',\n", + " tensor([[-0.0265, 0.0259, 0.0224, ..., -0.0080, -0.0394, 0.0290],\n", + " [-0.0101, -0.0256, 0.0079, ..., -0.0017, -0.0287, -0.0163],\n", + " [ 0.0079, -0.0021, -0.0299, ..., 0.0076, 0.0063, 0.0082],\n", + " ...,\n", + " [ 0.0061, 0.0121, 0.0275, ..., -0.0162, 0.0025, -0.0075],\n", + " [-0.0039, -0.0217, -0.0428, ..., -0.0253, 0.0231, 0.0095],\n", + " [-0.0187, 0.0077, -0.0442, ..., 0.0358, -0.0084, -0.0132]])),\n", + " ('model.layers.24.mlp.up_proj.weight',\n", + " tensor([[-0.0201, -0.0119, 0.0505, ..., -0.0025, -0.0187, 0.0011],\n", + " [-0.0105, 0.0154, -0.0163, ..., 0.0248, 0.0028, 0.0178],\n", + " [-0.0163, -0.0271, -0.0100, ..., 0.0129, -0.0220, 0.0269],\n", + " ...,\n", + " [ 0.0138, 0.0329, -0.0091, ..., 0.0038, -0.0194, -0.0223],\n", + " [ 0.0469, 0.0291, -0.0027, ..., 0.0231, 0.0261, 0.0151],\n", + " [-0.0093, -0.0098, 0.0013, ..., 0.0078, -0.0145, 0.0268]])),\n", + " ('model.layers.24.mlp.down_proj.weight',\n", + " tensor([[-0.0195, -0.0003, -0.0046, ..., -0.0132, -0.0118, 0.0242],\n", + " [-0.0267, 0.0199, 0.0243, ..., -0.0063, 0.0134, -0.0163],\n", + " [-0.0044, -0.0303, -0.0215, ..., -0.0148, -0.0216, 0.0079],\n", + " ...,\n", + " [ 0.0159, 0.0180, 0.0098, ..., -0.0126, 0.0176, 0.0087],\n", + " [-0.0203, 0.0041, -0.0256, ..., -0.0047, -0.0236, -0.0256],\n", + " [-0.0017, 0.0133, 0.0490, ..., -0.0344, -0.0118, 0.0020]])),\n", + " ('model.layers.24.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.24.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.25.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.25.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.25.mixer.in_proj.weight',\n", + " tensor([[ 0.0064, 0.0039, 0.0014, ..., 0.0130, -0.0169, 0.0010],\n", + " [ 0.0371, 0.0241, 0.0203, ..., 0.0078, 0.0463, 0.0034],\n", + " [ 0.0184, -0.0431, -0.0026, ..., -0.0164, 0.0279, -0.0138],\n", + " ...,\n", + " [ 0.0146, -0.0138, -0.0418, ..., 0.0234, 0.0145, -0.0213],\n", + " [ 0.0124, -0.0298, -0.0164, ..., -0.0169, 0.0026, -0.0180],\n", + " [-0.0250, -0.0008, -0.0133, ..., -0.0131, -0.0064, 0.0071]])),\n", + " ('model.layers.25.mixer.conv1d.weight',\n", + " tensor([[[ 0.0171, -0.3423, -0.1701, 0.4869]],\n", + " \n", + " [[-0.4648, 0.4797, 0.3531, -0.3819]],\n", + " \n", + " [[-0.1660, -0.3489, -0.2488, 0.4428]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3545, -0.1567, -0.2646, 0.3590]],\n", + " \n", + " [[-0.2175, 0.4394, 0.3840, 0.2620]],\n", + " \n", + " [[ 0.1335, -0.3655, 0.3256, -0.1752]]])),\n", + " ('model.layers.25.mixer.conv1d.bias',\n", + " tensor([-0.0935, 0.0170, 0.0779, ..., -0.2362, 0.2879, 0.2390])),\n", + " ('model.layers.25.mixer.out_proj.weight',\n", + " tensor([[ 2.0220e-02, 5.0645e-05, -1.7425e-02, ..., 8.6082e-03,\n", + " -1.8566e-02, 1.3872e-02],\n", + " [ 2.9139e-02, 1.1096e-02, 4.4168e-02, ..., 3.5600e-02,\n", + " 7.3446e-03, -1.6368e-02],\n", + " [-3.2418e-02, 6.9682e-03, 3.1648e-02, ..., 1.4050e-02,\n", + " -1.6554e-02, 7.2751e-03],\n", + " ...,\n", + " [-3.3057e-02, -7.0545e-04, 3.9661e-02, ..., 2.0690e-02,\n", + " -1.0262e-02, -4.9292e-03],\n", + " [ 1.9849e-02, 1.9666e-02, -1.9398e-02, ..., 1.9285e-02,\n", + " 2.2522e-02, -6.0243e-03],\n", + " [ 1.7683e-02, 2.4301e-02, 7.2223e-03, ..., 3.1373e-02,\n", + " -5.7889e-03, 1.1855e-02]])),\n", + " ('model.layers.25.mlp.gate_proj.weight',\n", + " tensor([[-1.6223e-02, 4.5519e-03, -1.9218e-02, ..., 6.3580e-03,\n", + " -1.2723e-02, -9.7756e-03],\n", + " [-7.4200e-03, 1.8729e-02, 2.6924e-03, ..., 8.2305e-03,\n", + " -1.5727e-02, -9.8748e-03],\n", + " [ 3.2143e-02, -6.1559e-02, 1.6362e-02, ..., -3.6189e-04,\n", + " 1.2017e-04, -1.5734e-02],\n", + " ...,\n", + " [-1.4649e-02, -4.7663e-03, -1.9292e-02, ..., -1.9359e-02,\n", + " 1.8795e-02, 1.0221e-02],\n", + " [-2.4459e-02, 1.1684e-02, -2.8023e-02, ..., 8.0104e-03,\n", + " 8.5950e-05, 1.0542e-02],\n", + " [-4.5679e-03, -1.1421e-02, -2.1099e-02, ..., 4.5089e-03,\n", + " -3.0686e-02, -9.6116e-03]])),\n", + " ('model.layers.25.mlp.up_proj.weight',\n", + " tensor([[-0.0204, -0.0013, -0.0264, ..., -0.0081, -0.0027, 0.0215],\n", + " [-0.0161, 0.0051, -0.0111, ..., -0.0244, 0.0043, -0.0043],\n", + " [-0.0511, 0.0006, -0.0249, ..., 0.0069, 0.0615, 0.0123],\n", + " ...,\n", + " [-0.0086, -0.0016, 0.0064, ..., -0.0347, 0.0097, -0.0134],\n", + " [-0.0003, 0.0015, -0.0053, ..., 0.0210, 0.0135, 0.0337],\n", + " [-0.0205, 0.0028, -0.0272, ..., -0.0168, -0.0072, 0.0019]])),\n", + " ('model.layers.25.mlp.down_proj.weight',\n", + " tensor([[ 0.0166, 0.0044, 0.0180, ..., -0.0127, 0.0070, -0.0066],\n", + " [-0.0056, 0.0140, 0.0151, ..., -0.0239, -0.0140, 0.0470],\n", + " [-0.0030, -0.0093, -0.0188, ..., -0.0090, -0.0092, -0.0088],\n", + " ...,\n", + " [ 0.0465, 0.0277, -0.0349, ..., 0.0424, 0.0015, 0.0206],\n", + " [-0.0096, 0.0174, 0.0250, ..., -0.0142, -0.0022, -0.0141],\n", + " [-0.0195, -0.0174, 0.0033, ..., 0.0027, -0.0061, -0.0108]])),\n", + " ('model.layers.25.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.25.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.26.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.26.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.26.mixer.in_proj.weight',\n", + " tensor([[ 0.0112, 0.0060, -0.0038, ..., -0.0164, 0.0111, 0.0105],\n", + " [ 0.0227, -0.0248, 0.0240, ..., 0.0103, -0.0373, -0.0051],\n", + " [-0.0073, 0.0227, -0.0190, ..., 0.0048, -0.0101, -0.0137],\n", + " ...,\n", + " [ 0.0086, -0.0084, 0.0177, ..., -0.0245, 0.0119, 0.0022],\n", + " [-0.0080, -0.0284, 0.0440, ..., 0.0340, -0.0093, 0.0130],\n", + " [-0.0107, 0.0234, -0.0279, ..., 0.0106, -0.0169, -0.0001]])),\n", + " ('model.layers.26.mixer.conv1d.weight',\n", + " tensor([[[ 0.0550, -0.3464, -0.2378, -0.1244]],\n", + " \n", + " [[-0.0925, -0.2497, 0.2629, -0.1821]],\n", + " \n", + " [[-0.4524, 0.3462, -0.4604, -0.2758]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4555, -0.0839, 0.3936, -0.3707]],\n", + " \n", + " [[ 0.3409, -0.4109, 0.0890, -0.3629]],\n", + " \n", + " [[-0.2769, 0.4033, -0.1090, 0.3055]]])),\n", + " ('model.layers.26.mixer.conv1d.bias',\n", + " tensor([-0.2286, -0.2395, -0.2517, ..., 0.0537, 0.0906, 0.4936])),\n", + " ('model.layers.26.mixer.out_proj.weight',\n", + " tensor([[-0.0316, -0.0423, -0.0053, ..., 0.0024, 0.0084, -0.0270],\n", + " [ 0.0458, -0.0243, 0.0060, ..., -0.0007, -0.0161, -0.0232],\n", + " [ 0.0388, -0.0126, 0.0184, ..., -0.0059, 0.0061, 0.0090],\n", + " ...,\n", + " [ 0.0487, 0.0305, -0.0175, ..., -0.0250, -0.0158, -0.0035],\n", + " [-0.0148, -0.0224, 0.0095, ..., -0.0102, -0.0226, 0.0272],\n", + " [-0.0061, 0.0067, 0.0069, ..., 0.0038, -0.0277, -0.0168]])),\n", + " ('model.layers.26.mlp.gate_proj.weight',\n", + " tensor([[-1.9812e-02, 8.3232e-03, 3.0347e-03, ..., 2.1982e-02,\n", + " 1.3550e-02, -1.1203e-02],\n", + " [ 2.2460e-02, 4.9811e-03, -2.2167e-02, ..., 1.3932e-03,\n", + " 5.3891e-03, -2.8310e-02],\n", + " [ 1.1011e-02, -1.2903e-02, -2.8861e-02, ..., 2.6808e-02,\n", + " -2.8479e-03, -1.3105e-02],\n", + " ...,\n", + " [ 1.1078e-03, -1.1789e-02, -4.4165e-02, ..., 8.2950e-03,\n", + " -1.8015e-02, -1.2234e-02],\n", + " [-2.0721e-02, -4.7919e-04, -4.9474e-02, ..., 7.9999e-05,\n", + " 1.7886e-02, -4.4699e-02],\n", + " [ 8.1279e-03, 1.2636e-02, -2.0932e-02, ..., -3.0361e-03,\n", + " 3.3468e-03, 2.7677e-02]])),\n", + " ('model.layers.26.mlp.up_proj.weight',\n", + " tensor([[-0.0301, -0.0025, -0.0147, ..., -0.0186, 0.0058, -0.0057],\n", + " [ 0.0303, -0.0341, 0.0142, ..., -0.0252, -0.0247, 0.0280],\n", + " [ 0.0209, -0.0425, 0.0073, ..., 0.0063, -0.0040, -0.0076],\n", + " ...,\n", + " [-0.0172, -0.0199, 0.0125, ..., 0.0363, 0.0118, -0.0124],\n", + " [-0.0108, 0.0042, -0.0475, ..., 0.0091, -0.0185, 0.0144],\n", + " [-0.0275, -0.0049, 0.0183, ..., -0.0001, -0.0119, -0.0359]])),\n", + " ('model.layers.26.mlp.down_proj.weight',\n", + " tensor([[-0.0197, -0.0082, -0.0224, ..., -0.0469, -0.0076, -0.0375],\n", + " [-0.0070, -0.0071, 0.0190, ..., -0.0125, 0.0068, 0.0166],\n", + " [ 0.0062, -0.0072, 0.0189, ..., -0.0244, -0.0292, -0.0328],\n", + " ...,\n", + " [-0.0054, 0.0219, 0.0058, ..., 0.0118, 0.0136, -0.0221],\n", + " [-0.0133, 0.0299, -0.0182, ..., -0.0496, -0.0202, 0.0196],\n", + " [-0.0131, -0.0237, -0.0473, ..., 0.0066, 0.0119, 0.0100]])),\n", + " ('model.layers.26.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.26.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.27.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.27.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.27.mixer.in_proj.weight',\n", + " tensor([[ 0.0200, -0.0276, -0.0274, ..., 0.0282, 0.0025, 0.0215],\n", + " [ 0.0054, 0.0218, -0.0175, ..., -0.0054, 0.0211, -0.0073],\n", + " [ 0.0100, -0.0023, 0.0162, ..., 0.0008, -0.0193, -0.0050],\n", + " ...,\n", + " [-0.0241, -0.0197, -0.0142, ..., 0.0039, -0.0175, 0.0045],\n", + " [ 0.0214, 0.0137, -0.0155, ..., -0.0212, 0.0089, 0.0165],\n", + " [ 0.0086, 0.0181, 0.0069, ..., -0.0093, -0.0272, 0.0068]])),\n", + " ('model.layers.27.mixer.conv1d.weight',\n", + " tensor([[[ 0.0519, 0.2061, 0.2635, 0.4916]],\n", + " \n", + " [[ 0.3745, -0.0860, -0.2310, -0.4250]],\n", + " \n", + " [[ 0.0565, 0.3699, 0.2812, -0.4201]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.4073, 0.1852, -0.1687, -0.2643]],\n", + " \n", + " [[-0.0865, -0.0894, 0.2650, -0.4522]],\n", + " \n", + " [[-0.0987, 0.0925, -0.2098, 0.0325]]])),\n", + " ('model.layers.27.mixer.conv1d.bias',\n", + " tensor([-0.4788, -0.0231, -0.4210, ..., -0.3143, -0.2893, 0.0570])),\n", + " ('model.layers.27.mixer.out_proj.weight',\n", + " tensor([[-0.0294, -0.0038, -0.0213, ..., -0.0141, 0.0072, -0.0359],\n", + " [ 0.0131, 0.0173, 0.0159, ..., 0.0030, 0.0400, -0.0065],\n", + " [-0.0111, 0.0374, 0.0109, ..., -0.0338, 0.0312, 0.0073],\n", + " ...,\n", + " [-0.0004, 0.0282, 0.0148, ..., 0.0165, 0.0062, -0.0177],\n", + " [ 0.0265, -0.0331, -0.0056, ..., 0.0407, 0.0154, 0.0176],\n", + " [ 0.0209, -0.0293, 0.0009, ..., -0.0240, -0.0029, -0.0407]])),\n", + " ('model.layers.27.mlp.gate_proj.weight',\n", + " tensor([[-0.0118, 0.0202, -0.0012, ..., 0.0101, 0.0075, 0.0102],\n", + " [ 0.0102, -0.0062, 0.0330, ..., -0.0024, -0.0245, -0.0237],\n", + " [-0.0008, 0.0202, -0.0097, ..., 0.0022, -0.0152, -0.0128],\n", + " ...,\n", + " [-0.0461, 0.0178, 0.0253, ..., 0.0319, 0.0173, -0.0099],\n", + " [ 0.0014, -0.0256, 0.0224, ..., 0.0272, 0.0045, 0.0192],\n", + " [ 0.0146, -0.0357, -0.0089, ..., -0.0147, 0.0383, 0.0354]])),\n", + " ('model.layers.27.mlp.up_proj.weight',\n", + " tensor([[-3.1854e-02, -1.0290e-03, -3.4564e-03, ..., 3.3551e-03,\n", + " 3.2845e-02, 2.1107e-02],\n", + " [-4.8083e-04, -5.8388e-03, 1.7324e-03, ..., 2.0575e-02,\n", + " -1.1685e-02, 1.2504e-02],\n", + " [ 4.6267e-02, -1.8935e-02, -2.4184e-02, ..., -4.8211e-02,\n", + " -3.3912e-04, 3.0527e-02],\n", + " ...,\n", + " [-6.9427e-03, -4.8680e-03, 3.2021e-02, ..., 1.4236e-02,\n", + " 1.9532e-02, 1.3339e-02],\n", + " [ 1.2463e-02, -5.5923e-03, -1.5680e-02, ..., 8.7956e-03,\n", + " 2.8262e-02, -1.2526e-02],\n", + " [-4.8530e-03, -8.8749e-05, 3.3507e-02, ..., -2.8260e-02,\n", + " -2.0571e-03, -8.3943e-03]])),\n", + " ('model.layers.27.mlp.down_proj.weight',\n", + " tensor([[-0.0457, -0.0267, -0.0210, ..., -0.0093, -0.0016, -0.0008],\n", + " [-0.0053, 0.0284, -0.0003, ..., 0.0065, -0.0117, 0.0243],\n", + " [ 0.0120, 0.0023, -0.0180, ..., -0.0003, -0.0313, 0.0163],\n", + " ...,\n", + " [-0.0160, 0.0207, 0.0082, ..., 0.0153, 0.0131, 0.0034],\n", + " [-0.0073, 0.0424, 0.0274, ..., -0.0075, -0.0554, -0.0114],\n", + " [-0.0192, 0.0268, 0.0036, ..., 0.0094, 0.0045, 0.0030]])),\n", + " ('model.layers.27.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.27.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.norm.weight', tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('lm_head.weight',\n", + " tensor([[-0.0141, -0.0445, 0.0071, ..., -0.0143, -0.0239, -0.0512],\n", + " [ 0.0295, -0.0317, -0.0201, ..., -0.0082, 0.0231, -0.0030],\n", + " [-0.0255, -0.0139, 0.0020, ..., -0.0040, -0.0154, 0.0336],\n", + " ...,\n", + " [ 0.0095, 0.0361, 0.0135, ..., -0.0018, 0.0074, -0.0311],\n", + " [-0.0092, 0.0060, 0.0594, ..., -0.0046, 0.0117, 0.0364],\n", + " [ 0.0228, -0.0265, -0.0262, ..., 0.0038, 0.0097, -0.0257]]))])" ] }, - "execution_count": 9, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "apriel_ssm_config" + "apriel_ssm.state_dict()" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "N params SSM: 5.660780512\n" + "N params SSM: 5.305533088\n" ] } ], @@ -246,7 +2222,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -258,10 +2234,10 @@ " (layers): ModuleList(\n", " (0-27): 28 x AprielDecoderLayer(\n", " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=11304, bias=False)\n", - " (conv1d): Conv1d(7176, 7176, kernel_size=(4,), stride=(1,), padding=(3,), groups=7176)\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", " (act): Identity()\n", - " (out_proj): Linear(in_features=4104, out_features=4096, bias=False)\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", " )\n", " (mlp): AprielMLP(\n", " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", @@ -279,7 +2255,7 @@ ")" ] }, - "execution_count": 11, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -291,7 +2267,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -300,7 +2276,7 @@ "_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.weight', 'model.layers.7.mixer.conv1d.bias', 'model.layers.7.mixer.out_proj.weight', 'model.layers.8.mixer.z_bias', 'model.layers.8.mixer.D', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.out_proj.weight', 'model.layers.9.mixer.z_bias', 'model.layers.9.mixer.D', 'model.layers.9.mixer.in_proj.weight', 'model.layers.9.mixer.conv1d.weight', 'model.layers.9.mixer.conv1d.bias', 'model.layers.9.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.19.mixer.z_bias', 'model.layers.19.mixer.D', 'model.layers.19.mixer.in_proj.weight', 'model.layers.19.mixer.conv1d.weight', 'model.layers.19.mixer.conv1d.bias', 'model.layers.19.mixer.out_proj.weight', 'model.layers.20.mixer.z_bias', 'model.layers.20.mixer.D', 'model.layers.20.mixer.in_proj.weight', 'model.layers.20.mixer.conv1d.weight', 'model.layers.20.mixer.conv1d.bias', 'model.layers.20.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 'model.layers.21.mixer.out_proj.weight', 'model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight', 'model.layers.23.mixer.z_bias', 'model.layers.23.mixer.D', 'model.layers.23.mixer.in_proj.weight', 'model.layers.23.mixer.conv1d.weight', 'model.layers.23.mixer.conv1d.bias', 'model.layers.23.mixer.out_proj.weight', 'model.layers.24.mixer.z_bias', 'model.layers.24.mixer.D', 'model.layers.24.mixer.in_proj.weight', 'model.layers.24.mixer.conv1d.weight', 'model.layers.24.mixer.conv1d.bias', 'model.layers.24.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight', 'model.layers.26.mixer.z_bias', 'model.layers.26.mixer.D', 'model.layers.26.mixer.in_proj.weight', 'model.layers.26.mixer.conv1d.weight', 'model.layers.26.mixer.conv1d.bias', 'model.layers.26.mixer.out_proj.weight', 'model.layers.27.mixer.z_bias', 'model.layers.27.mixer.D', 'model.layers.27.mixer.in_proj.weight', 'model.layers.27.mixer.conv1d.weight', 'model.layers.27.mixer.conv1d.bias', 'model.layers.27.mixer.out_proj.weight'], unexpected_keys=['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight'])" ] }, - "execution_count": 12, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -311,7 +2287,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -323,10 +2299,10 @@ " (layers): ModuleList(\n", " (0-27): 28 x AprielDecoderLayer(\n", " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=11304, bias=False)\n", - " (conv1d): Conv1d(7176, 7176, kernel_size=(4,), stride=(1,), padding=(3,), groups=7176)\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", " (act): Identity()\n", - " (out_proj): Linear(in_features=4104, out_features=4096, bias=False)\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", " )\n", " (mlp): AprielMLP(\n", " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", @@ -344,7 +2320,7 @@ ")" ] }, - "execution_count": 13, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -356,7 +2332,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -372,26 +2348,29 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 2, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/toolkit/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:2714: UserWarning: `save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead.\n", - " warnings.warn(\n" + "ename": "NameError", + "evalue": "name 'apriel_ssm' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mapriel_ssm\u001b[49m\u001b[38;5;241m.\u001b[39msave_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints/ssm/apriel_ssm_instruct_base\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 2\u001b[0m save_config\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'apriel_ssm' is not defined" ] } ], "source": [ - "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm\",\n", + "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_base\",\n", " save_config=True)\n" ] }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -400,7 +2379,7 @@ "24" ] }, - "execution_count": 60, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -411,7 +2390,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -423,10 +2402,10 @@ " (layers): ModuleList(\n", " (0-27): 28 x AprielDecoderLayer(\n", " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=12320, bias=False)\n", - " (conv1d): Conv1d(8192, 8192, kernel_size=(4,), stride=(1,), padding=(3,), groups=8192)\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", " (act): Identity()\n", - " (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", " )\n", " (mlp): AprielMLP(\n", " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", @@ -439,13 +2418,12 @@ " )\n", " )\n", " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (rotary_emb): AprielRotaryEmbedding()\n", " )\n", " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", ")" ] }, - "execution_count": 10, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -463,7 +2441,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -482,30 +2460,30 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-5.4688, -1.6641, 0.4609, ..., -7.1562, -3.7812, -5.9062],\n", - " [-3.5000, 1.4297, 4.3125, ..., -5.3438, -4.9375, -2.9844],\n", - " [-3.1094, 0.7930, 2.2969, ..., -3.1250, -4.1875, -2.1250],\n", + "CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-3.0781, 2.3594, 1.4609, ..., -2.3438, -1.9688, 0.6484],\n", + " [-5.8125, 4.9688, 0.4414, ..., -4.2500, -3.5156, -4.8125],\n", + " [-5.5000, 3.3594, 1.1484, ..., -3.4375, -2.3125, -4.4375],\n", " ...,\n", - " [-5.3438, -3.0938, -3.9062, ..., -4.9062, -3.0000, -3.9688],\n", - " [-3.0625, -3.2188, 5.6562, ..., -2.7812, -2.5938, -6.6562],\n", - " [-1.8438, -1.7500, 5.9062, ..., -3.7188, -2.1250, -0.8281]]],\n", - " device='cuda:0', grad_fn=), all_hidden_states=(), last_hidden_state=tensor([[[ 1.2266, 0.5547, -1.1953, ..., 0.1089, -2.5781, 0.6328],\n", - " [-0.4395, 0.5938, -0.1562, ..., -0.6719, -0.6367, -0.3086],\n", - " [ 0.0077, 0.6680, -1.0703, ..., -3.6875, 0.2207, 0.1299],\n", + " [-2.2812, 0.1465, 2.2344, ..., -7.6875, -3.0312, -6.2500],\n", + " [-6.8750, 1.7812, -1.3750, ..., -7.4688, -5.6875, -4.4062],\n", + " [-2.0156, 2.0938, 3.1094, ..., -3.0156, -2.1406, -2.2812]]],\n", + " device='cuda:0', grad_fn=), all_hidden_states=(), last_hidden_state=tensor([[[-1.3828, 0.0625, -2.7500, ..., -0.6523, -0.8906, 1.4609],\n", + " [ 2.1406, -0.0247, -3.0156, ..., -0.0074, 1.0234, 1.3828],\n", + " [ 1.6016, -0.7266, -1.2422, ..., -0.4004, -0.8242, -0.5586],\n", " ...,\n", - " [-0.0703, 0.4551, 0.1104, ..., 1.3438, 1.3984, 1.1641],\n", - " [-0.0613, 1.9141, -0.5430, ..., -1.0312, -0.6680, 0.0518],\n", - " [-0.6172, 0.2148, -0.5977, ..., -1.2734, -0.1914, 2.2344]]],\n", + " [ 1.5234, -0.0262, -1.5469, ..., -0.4922, -1.0078, 1.2344],\n", + " [-0.4629, -0.6055, -1.3906, ..., -0.9922, -0.3066, 1.1875],\n", + " [-0.7539, -0.0243, -2.4688, ..., -1.0625, -2.7188, 2.6875]]],\n", " device='cuda:0', dtype=torch.bfloat16, grad_fn=))" ] }, - "execution_count": 56, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -520,6 +2498,94 @@ "metadata": {}, "outputs": [], "source": [] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import enum" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "class SSMBlockType(str, enum.Enum):\n", + " \"\"\"\n", + " An enum for the available mamba types for the MLP layer.\n", + " \"\"\"\n", + "\n", + " mamba = \"m\"\n", + " mamba2_discrete = \"m2d\"\n", + " mamba2 = \"m2\"\n", + " transformer = \"t\"" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_values([, , , ])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'m' in SSMBlockType.__members__.values()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "'m'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[21], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mm\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[43mSSMBlockType\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mm\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241m.\u001b[39mname\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/enum.py:808\u001b[0m, in \u001b[0;36mEnumType.__getitem__\u001b[0;34m(cls, name)\u001b[0m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mcls\u001b[39m, name):\n\u001b[1;32m 805\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 806\u001b[0m \u001b[38;5;124;03m Return the member matching `name`.\u001b[39;00m\n\u001b[1;32m 807\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 808\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_member_map_\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\n", + "\u001b[0;31mKeyError\u001b[0m: 'm'" + ] + } + ], + "source": [ + "\"m\" == SSMBlockType[\"m\"].name\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'m2d'" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "SSMBlockType.mamba2_discrete.value" + ] } ], "metadata": { From 77ad39f7730f314d01c3c8f5da14f1ac8aabf5f4 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 30 Apr 2025 20:21:08 +0000 Subject: [PATCH 043/122] add token-prediction loss coefficients --- fast_llm/layers/language_model/config.py | 10 ++++++++++ fast_llm/layers/language_model/head.py | 5 ++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index c99ee4f6a..c675361a2 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -215,6 +215,12 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): doc="May be used to freeze the output weights by setting their scale to zero.", hint=FieldHint.feature, ) + prediction_loss_coefficient: list[float] | None = Field( + default=None, + desc="Loss coefficient for each prediction head.", + doc="If not provided, all heads are equally weighted.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.transformer.validate() @@ -231,3 +237,7 @@ def _validate(self) -> None: if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") + if isinstance(self.prediction_loss_coefficient, list): + Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) + for coeff in self.prediction_loss_coefficient: + Assert.geq(coeff, 0) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 1153fb2c2..014a617cb 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -57,6 +57,9 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + self._loss_coefficient = ( + config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 + ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) self.final_norm = config.transformer.normalization.get_layer(hidden_dim) self._logits_scale_factor = config.logits_scale_factor @@ -133,7 +136,7 @@ def forward( else: if self.training: # Backward hook to compute the gradient of the loss - shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0) + shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, self._loss_coefficient) # MTP: Return shared_hidden to be used by the next head. return shared_hidden From da9bf1a78efbf4a34af23b9ea55ee1b98343230c Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 1 May 2025 13:54:27 +0000 Subject: [PATCH 044/122] eval apriel ssm --- .../ssm/external/eval/apriel_eval_wrapper.py | 59 +++++++++++++++++++ .../models/ssm/external/eval/run_lm_eval.py | 6 ++ .../ssm/external/modeling_ssm_apriel.py | 55 ++++++++++------- 3 files changed, 98 insertions(+), 22 deletions(-) create mode 100644 fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py create mode 100644 fast_llm/models/ssm/external/eval/run_lm_eval.py diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py new file mode 100644 index 000000000..94537c331 --- /dev/null +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -0,0 +1,59 @@ +from typing import Optional, Union + +import lm_eval.models.utils +import torch +from lm_eval.api.registry import register_model +from lm_eval.models.huggingface import HFLM + + +@register_model("apriel_ssm") +class AprielSSMWrapper(HFLM): + """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + # rene currently only supports causal models + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig + + self._config = AprielSSMConfig.from_pretrained(pretrained) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM + + self._model = AprielSSMForCausalLM.from_pretrained( + pretrained, + device=self._device, + dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + trust_remote_code=True, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + """Generate text from the model.""" + for key in ("do_sample", "attention_mask"): + if key in generation_kwargs: + generation_kwargs.pop(key) + + # The custom GenerationMixin imported from mamba_ssm currently does not support + # passing stopping criteria. + # For the time being, we simply generate to max length, then truncate (equivalent result). + # This should be revisited to speed up generation + # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) + + return self.model.generate( + input_ids=context, + max_length=max_length, + **generation_kwargs, + ) diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm/models/ssm/external/eval/run_lm_eval.py new file mode 100644 index 000000000..af07869a8 --- /dev/null +++ b/fast_llm/models/ssm/external/eval/run_lm_eval.py @@ -0,0 +1,6 @@ +from lm_eval.__main__ import cli_evaluate + +from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import AprielSSMWrapper # noqa: F401 + +if __name__ == "__main__": + cli_evaluate() diff --git a/fast_llm/models/ssm/external/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/modeling_ssm_apriel.py index d30d5b66c..5a1b8db42 100644 --- a/fast_llm/models/ssm/external/modeling_ssm_apriel.py +++ b/fast_llm/models/ssm/external/modeling_ssm_apriel.py @@ -19,7 +19,7 @@ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils.generic import ModelOutput -from .configuration_ssm_apriel import AprielSSMConfig +from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig logger = logging.get_logger(__name__) @@ -35,12 +35,13 @@ class CustomMambaCausalLMOutput(ModelOutput): class AprielRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): """ AprielRMSNorm is equivalent to T5LayerNorm """ + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) self.variance_epsilon = eps def forward(self, hidden_states): @@ -58,14 +59,15 @@ def extra_repr(self): class AprielMLP(nn.Module): - def __init__(self, config): - super().__init__() + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): @@ -437,19 +439,21 @@ def convolutional_step(self, xBC, conv_state): class AprielDecoderLayer(nn.Module): - def __init__(self, config: AprielSSMConfig, layer_idx: int): - super().__init__() + def __init__(self, config: AprielSSMConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} self.hidden_size = config.hidden_size self.mixer = DiscreteMamba2( d_model=config.hidden_size, layer_idx=layer_idx, **config.ssm_cfg, + **factory_kwargs, ) - self.mlp = AprielMLP(config) - self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = AprielMLP(config, **factory_kwargs) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) def forward( self, hidden_states: torch.Tensor, inference_params=None, **kwargs @@ -598,16 +602,16 @@ class AprielSSMModel(AprielSSMPreTrainedModel): config: AprielSSMConfig """ - def __init__(self, config: AprielSSMConfig): - super().__init__(config) + def __init__(self, config: AprielSSMConfig, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + factory_kwargs = {"device": device, "dtype": dtype} + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) self.layers = nn.ModuleList( - [AprielDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [AprielDecoderLayer(config, layer_idx, **factory_kwargs) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -664,11 +668,12 @@ class AprielSSMForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - def __init__(self, config): - super().__init__(config) + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) self.model = AprielSSMModel(config) self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + factory_kwargs = {"device": device, "dtype": dtype} + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) # Initialize weights and apply final processing self.post_init() @@ -722,6 +727,12 @@ def forward( last_hidden_state=outputs["last_hidden_state"], ) + def generate(self, *args, **kwargs): + """ + This is a wrapper to make sure we comply with the HF generation interface for eval harness + """ + return super().generate(*args, **kwargs) + __all__ = [ "AprielSSMForCausalLM", From ac4a5982d8ea8d04f8d8ebdb693e59f82f58e6b9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 1 May 2025 12:26:50 -0400 Subject: [PATCH 045/122] fix --- fast_llm/models/gpt/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9e28373b3..80c9caa23 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -214,7 +214,6 @@ def preprocess_meta( reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] for key in ( TransformerKwargs.sequence_first, - TransformerKwargs.hidden_dims, TransformerKwargs.sequence_length, TransformerKwargs.sequence_q_dim, TransformerKwargs.sequence_k_dim, From 0c0e7d9cca6b3ca31b11d4764e9664263badc8cc Mon Sep 17 00:00:00 2001 From: Luke Nitish Kumar Date: Thu, 1 May 2025 13:23:47 -0400 Subject: [PATCH 046/122] adding check for missing `rope_type` (#246) --- fast_llm/models/gpt/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index bc8bea266..d4d581cd7 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -407,7 +407,7 @@ def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: (export_value,) = export_values - if export_value is None or (rope_type := export_value[self._HUGGINGFACE_NAMES[0]]) == "default": + if export_value is None or export_value is MISSING or (rope_type := export_value[self._HUGGINGFACE_NAMES[0]]) == "default": return (RotaryEmbeddingType.default,) + (DEFAULT,) * 7 elif rope_type == RotaryEmbeddingType.llama3: return ("llama3", *[export_value.get(key, DEFAULT) for key in self._HUGGINGFACE_NAMES[1:]]) From 97ba9d44554908260dd24e1bb452d6bc00eee11a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 1 May 2025 14:04:02 -0400 Subject: [PATCH 047/122] Loss masking for distillation --- fast_llm/functional/cross_entropy.py | 40 ++++++++++++++------- fast_llm/functional/triton/cross_entropy.py | 7 ++++ fast_llm/layers/language_model/config.py | 1 + fast_llm/layers/language_model/head.py | 21 ++++++++--- fast_llm/models/gpt/model.py | 8 +++-- 5 files changed, 58 insertions(+), 19 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 1eb6c8c04..53b3e59ba 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -8,9 +8,10 @@ from fast_llm.utils import Assert -def torch_cross_entropy_forward_backward( +def _torch_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, + loss_mask: torch.Tensor | None, grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, @@ -22,15 +23,25 @@ def torch_cross_entropy_forward_backward( TODO: loss masking only works for with labels format and if the masking index is set to -100. """ # Torch compile doesn't understand this. + if loss_mask is not None: + raise NotImplementedError(f"Torch cross-entropy from {target_format} doesn't support loss masking.") with torch.set_grad_enabled(grad_output is not None): logits_ = logits.float().detach().requires_grad_(grad_output is not None) if target_format == TargetFormat.logits: if logits_scale_factor != 1.0: target = target * logits_scale_factor target = torch.softmax(target, dim=-1) - loss = torch.nn.functional.cross_entropy( - logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target - ).mean() + if loss_mask is None: + loss = torch.nn.functional.cross_entropy( + logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target + ) + else: + loss = ( + torch.nn.functional.cross_entropy( + logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none" + ) + * loss_mask + ).mean() if grad_output is None: grad = None else: @@ -57,8 +68,8 @@ def _fused_softmax_base( return logits_norm, exp_logits, sum_exp_logits -# @torch.compile -def fused_softmax( +@torch.compile +def _fused_softmax( logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup = None, dim: int = -1 ) -> torch.Tensor: _, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group, dim) @@ -66,9 +77,10 @@ def fused_softmax( @torch.compile -def fused_cross_entropy_forward_backward( +def _fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, + loss_mask: torch.Tensor | None, grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, @@ -85,7 +97,7 @@ def fused_cross_entropy_forward_backward( logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) if target_format == TargetFormat.logits: - target = fused_softmax(target, logits_scale_factor, group) + target = _fused_softmax(target, logits_scale_factor, group) if target_format == TargetFormat.labels: target = target.unsqueeze(-1) @@ -101,8 +113,6 @@ def fused_cross_entropy_forward_backward( target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) target = (target - vocab_start_index) * target_mask else: - # TODO: Support masking - loss_mask = None # Target should be tensor-parallel already, no further manipulation needed. target_mask = None @@ -145,8 +155,8 @@ def fused_cross_entropy_forward_backward( _CROSS_ENTROPY_IMPLEMENTATIONS = { - CrossEntropyImpl.torch: torch_cross_entropy_forward_backward, - CrossEntropyImpl.fused: fused_cross_entropy_forward_backward, + CrossEntropyImpl.torch: _torch_cross_entropy_forward_backward, + CrossEntropyImpl.fused: _fused_cross_entropy_forward_backward, CrossEntropyImpl.triton: triton_cross_entropy_forward_backward, } @@ -154,6 +164,7 @@ def fused_cross_entropy_forward_backward( def cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, + loss_mask: torch.Tensor | None, grad_output: float | None, group: ProcessGroup | None = None, implementation: CrossEntropyImpl = CrossEntropyImpl.fused, @@ -169,12 +180,15 @@ def cross_entropy_forward_backward( if target_format == TargetFormat.labels: Assert.eq(target.shape, logits.shape[:-1]) Assert.eq(target.dtype, torch.int64) + assert loss_mask is None else: Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, target.shape) if group: Assert.eq(implementation, CrossEntropyImpl.fused) - return fused_cross_entropy_forward_backward( + return _fused_cross_entropy_forward_backward( logits, target, grad_output, logits_scale_factor, target_format, group ) else: diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index d825af034..02dc1ce78 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -57,12 +57,14 @@ def triton_cross_entropy_forward_backward_kernel( def triton_cross_entropy_from_distribution_forward_backward_kernel( logits_ptr, target_ptr, + loss_mask_ptr, grad_logits_ptr, losses_ptr, grad_losses, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, target_stride_0: tl_constexpr, + loss_mask_stride_0: tl_constexpr, grad_logits_stride_0: tl_constexpr, logits_scale_factor: tl_constexpr, from_logits: tl_constexpr, @@ -87,6 +89,8 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( target = tl.load(target_ptr + block_idx * target_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( tl.float32 ) + if loss_mask_ptr is not None: + loss_mask = tl.load(target_ptr + block_idx * target_stride_0 + col_offsets, mask=mask, other=0) if from_logits: if logits_scale_factor != 1.0: target *= logits_scale_factor @@ -110,6 +114,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( def triton_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, + loss_mask: torch.Tensor | None, grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, @@ -149,12 +154,14 @@ def triton_cross_entropy_forward_backward( triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( logits, target, + loss_mask, grad_logits, losses, None if grad_output is None else grad_output / n_rows, n_cols, logits.stride(0), target.stride(0), + None if loss_mask is None else loss_mask.stride(0), None if grad_output is None else grad_logits.stride(0), logits_scale_factor, block_size=block_size, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4fb471fb3..0371eff43 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -35,6 +35,7 @@ class LanguageModelKwargs: # TODO: These are generic labels = "labels" phase = "phase" + loss_mask = "loss_mask" @config_class() diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 3b476f6a3..9b1dd4d8a 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -146,6 +146,8 @@ def _forward_backward( if self._config.distillation_model is None else f"{self._config.distillation_model}_logits" ) + # Loss mask for distillation. (Labels are already masked.) + loss_mask = None if target is not None: if self._config.distillation_model is None: # MTP: Shift the labels @@ -160,9 +162,12 @@ def _forward_backward( else: # Target is reference model logits. target = target.flatten(0, -2) + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) if self._sequence_parallel_logits: target = split_op(target, self._tensor_space.distributed.tensor_group, 0) + if loss_mask is not None: + loss_mask = split_op(loss_mask, self._tensor_space.distributed.tensor_group, 0) do_grad = target is not None and self.training input_ = input_.detach().requires_grad_(do_grad) with torch.enable_grad(): @@ -174,7 +179,7 @@ def _forward_backward( output_weights = self._get_output_weights(kwargs) loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( - ln_output.detach(), target, output_weights, grad_output, kwargs, losses + ln_output.detach(), target, loss_mask, output_weights, grad_output, kwargs, losses ) if do_grad: @@ -194,6 +199,7 @@ def _logits_cross_entropy_forward_backward_split( self, input_: torch.Tensor, target: torch.Tensor | None, + loss_mask: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -201,7 +207,7 @@ def _logits_cross_entropy_forward_backward_split( ) -> tuple[torch.Tensor | None, torch.Tensor | None]: if self._cross_entropy_splits is None or target is None: loss, logit_input_grad = self._logits_cross_entropy_forward_backward( - input_, target, weight, grad_output, kwargs, losses + input_, target, loss_mask, weight, grad_output, kwargs, losses ) if target is None: # TODO: Make a proper way of returning the model output. @@ -214,12 +220,17 @@ def _logits_cross_entropy_forward_backward_split( grad_output /= self._cross_entropy_splits logit_input = input_.flatten(0, -2) logit_input_grad = torch.empty_like(logit_input) - for logit_input_, target_, logit_input_grad_ in zip( - logit_input.split(split_size), target.split(split_size), logit_input_grad.split(split_size) + for logit_input_, target_, loss_mask_, logit_input_grad_ in zip( + logit_input.split(split_size), + target.split(split_size), + [None] * self._cross_entropy_splits if loss_mask is None else loss_mask.split(split_size), + logit_input_grad.split(split_size), + strict=True, ): loss_, grad_ = self._logits_cross_entropy_forward_backward( logit_input_, target_, + loss_mask_, weight, grad_output, kwargs, @@ -240,6 +251,7 @@ def _logits_cross_entropy_forward_backward( self, input_: torch.Tensor, target: torch.Tensor | None, + loss_mask: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -298,6 +310,7 @@ def _logits_cross_entropy_forward_backward( loss, grad = cross_entropy_forward_backward( logits.flatten(0, -2), target, + loss_mask, group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, grad_output=grad_output, implementation=self._cross_entropy_impl, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 80c9caa23..9084fb40e 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -313,11 +313,15 @@ def preprocess( valid_spans[:, 0].clamp_(min=sequence_offset) valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1) valid_spans -= sequence_offset + loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - labels[start : end + 1, i] = -100 + loss_mask[start : end + 1, i] = False else: - labels[i, start : end + 1] = -100 + loss_mask[i, start : end + 1] = False + if self._config.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask + labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) From 231d5d82e535dbc805756604c420b366516584b5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 1 May 2025 14:33:53 -0400 Subject: [PATCH 048/122] test, misc --- fast_llm/functional/cross_entropy.py | 6 +-- tests/layers/test_lm_head.py | 71 ++++++++++++++++++++-------- 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 53b3e59ba..34c69d797 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -23,8 +23,6 @@ def _torch_cross_entropy_forward_backward( TODO: loss masking only works for with labels format and if the masking index is set to -100. """ # Torch compile doesn't understand this. - if loss_mask is not None: - raise NotImplementedError(f"Torch cross-entropy from {target_format} doesn't support loss masking.") with torch.set_grad_enabled(grad_output is not None): logits_ = logits.float().detach().requires_grad_(grad_output is not None) if target_format == TargetFormat.logits: @@ -40,7 +38,7 @@ def _torch_cross_entropy_forward_backward( torch.nn.functional.cross_entropy( logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none" ) - * loss_mask + * loss_mask.unsqueeze(-1) ).mean() if grad_output is None: grad = None @@ -185,7 +183,7 @@ def cross_entropy_forward_backward( Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: - Assert.eq(loss_mask.shape, target.shape) + Assert.eq(loss_mask.shape, logits.shape[:-1]) if group: Assert.eq(implementation, CrossEntropyImpl.fused) return _fused_cross_entropy_forward_backward( diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 79101f340..14edecffd 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -25,6 +25,7 @@ def _lm_head( input_: torch.Tensor, target: torch.Tensor, + loss_mask: torch.Tensor | None, *, # config:LanguageModelBaseConfig, rms_weight: torch.Tensor, @@ -43,7 +44,13 @@ def _lm_head( if logit_scale_factor != 1.0: logits *= logit_scale_factor z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None - loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) + if target.ndim == logits.ndim: + loss = torch.nn.functional.cross_entropy(logits, target, reduction="none") + if loss_mask is not None: + loss = loss * loss_mask.unsqueeze(-1) + loss = loss.mean() + else: + loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) loss.backward(torch.full_like(loss, grad_output)) return loss, z_loss @@ -58,22 +65,26 @@ def _lm_head( @pytest.mark.slow @pytest.mark.parametrize("cross_entropy_impl", tuple(CrossEntropyImpl)) @pytest.mark.parametrize( - ("config_dict", "distributed_config_dict"), + ("config_dict", "distributed_config_dict", "loss_masking"), ( - ({}, {}), - ({}, {"training_dtype": DataType.bfloat16}), - ({"transformer": {"full_precision_residual": True}}, {"training_dtype": DataType.bfloat16}), - ({"sequence_first": True}, {}), - ({"logit_z_loss": 1e-3}, {}), - ({"logits_scale_factor": 5.0}, {}), - ({"tie_word_embeddings": False}, {}), - ({"prediction_heads": 2}, {}), + ({}, {}, False), + ({}, {"training_dtype": DataType.bfloat16}, False), + ({"transformer": {"full_precision_residual": True}}, {"training_dtype": DataType.bfloat16}, False), + ({"sequence_first": True}, {}, False), + ({"logit_z_loss": 1e-3}, {}, False), + ({"logits_scale_factor": 5.0}, {}, False), + ({"tie_word_embeddings": False}, {}, False), + ({"prediction_heads": 2}, {}, False), + ({}, {}, True), + ({"distillation_model": "distillation"}, {}, False), + ({"distillation_model": "distillation"}, {}, True), ), ) def test_lm_head( cross_entropy_impl: CrossEntropyImpl, config_dict: dict[str, typing.Any], distributed_config_dict: dict[str, typing.Any], + loss_masking: bool, ): config = GPTBaseModelConfig.from_dict( { @@ -99,17 +110,6 @@ def test_lm_head( sequence_first = config.sequence_first or ( config.cross_entropy_splits is not None and config.cross_entropy_splits > 1 ) - target = torch.randint( - 0, - VOCAB_SIZE, - ( - (SEQUENCE_LENGTH + config.prediction_heads - 1, BATCH_SIZE) - if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) - ), - dtype=torch.int64, - device=distributed.device, - ) input_ = torch.randn( (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=( @@ -120,6 +120,34 @@ def test_lm_head( device=distributed.device, requires_grad=True, ) + label_shape = ( + (SEQUENCE_LENGTH + config.prediction_heads - 1, BATCH_SIZE) + if sequence_first + else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) + ) + if loss_masking: + loss_mask = torch.randint( + 0, + VOCAB_SIZE, + label_shape, + dtype=torch.bool, + device=distributed.device, + ) + else: + loss_mask = None + if config.distillation_model is None: + target = torch.randint( + 0, + VOCAB_SIZE, + label_shape, + dtype=torch.int64, + device=distributed.device, + ) + if loss_mask is not None: + target *= loss_mask + else: + assert config.prediction_heads == 1 + target = torch.randn_like(input_) kwargs = { TransformerKwargs.sequence_first: sequence_first, LanguageModelKwargs.labels: target, @@ -173,6 +201,7 @@ def test_lm_head( if sequence_first else target[:, prediction_distance : prediction_distance + SEQUENCE_LENGTH] ), + loss_mask, rms_weight=ref_rms_weight, logit_weight=ref_logit_weight, logit_scale_factor=config.logits_scale_factor, From 30a75b003360a37d907659668d4b4dcf9f61ac3a Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 1 May 2025 20:33:41 +0000 Subject: [PATCH 049/122] eval apriel ssm --- .../ssm/external/eval/apriel_eval_wrapper.py | 59 +++++++++++++++++++ .../models/ssm/external/eval/run_lm_eval.py | 6 ++ .../ssm/external/modeling_ssm_apriel.py | 55 ++++++++++------- 3 files changed, 98 insertions(+), 22 deletions(-) create mode 100644 fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py create mode 100644 fast_llm/models/ssm/external/eval/run_lm_eval.py diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py new file mode 100644 index 000000000..94537c331 --- /dev/null +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -0,0 +1,59 @@ +from typing import Optional, Union + +import lm_eval.models.utils +import torch +from lm_eval.api.registry import register_model +from lm_eval.models.huggingface import HFLM + + +@register_model("apriel_ssm") +class AprielSSMWrapper(HFLM): + """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + # rene currently only supports causal models + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig + + self._config = AprielSSMConfig.from_pretrained(pretrained) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM + + self._model = AprielSSMForCausalLM.from_pretrained( + pretrained, + device=self._device, + dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + trust_remote_code=True, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + """Generate text from the model.""" + for key in ("do_sample", "attention_mask"): + if key in generation_kwargs: + generation_kwargs.pop(key) + + # The custom GenerationMixin imported from mamba_ssm currently does not support + # passing stopping criteria. + # For the time being, we simply generate to max length, then truncate (equivalent result). + # This should be revisited to speed up generation + # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) + + return self.model.generate( + input_ids=context, + max_length=max_length, + **generation_kwargs, + ) diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm/models/ssm/external/eval/run_lm_eval.py new file mode 100644 index 000000000..af07869a8 --- /dev/null +++ b/fast_llm/models/ssm/external/eval/run_lm_eval.py @@ -0,0 +1,6 @@ +from lm_eval.__main__ import cli_evaluate + +from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import AprielSSMWrapper # noqa: F401 + +if __name__ == "__main__": + cli_evaluate() diff --git a/fast_llm/models/ssm/external/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/modeling_ssm_apriel.py index d30d5b66c..5a1b8db42 100644 --- a/fast_llm/models/ssm/external/modeling_ssm_apriel.py +++ b/fast_llm/models/ssm/external/modeling_ssm_apriel.py @@ -19,7 +19,7 @@ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils.generic import ModelOutput -from .configuration_ssm_apriel import AprielSSMConfig +from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig logger = logging.get_logger(__name__) @@ -35,12 +35,13 @@ class CustomMambaCausalLMOutput(ModelOutput): class AprielRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): """ AprielRMSNorm is equivalent to T5LayerNorm """ + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) self.variance_epsilon = eps def forward(self, hidden_states): @@ -58,14 +59,15 @@ def extra_repr(self): class AprielMLP(nn.Module): - def __init__(self, config): - super().__init__() + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): @@ -437,19 +439,21 @@ def convolutional_step(self, xBC, conv_state): class AprielDecoderLayer(nn.Module): - def __init__(self, config: AprielSSMConfig, layer_idx: int): - super().__init__() + def __init__(self, config: AprielSSMConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} self.hidden_size = config.hidden_size self.mixer = DiscreteMamba2( d_model=config.hidden_size, layer_idx=layer_idx, **config.ssm_cfg, + **factory_kwargs, ) - self.mlp = AprielMLP(config) - self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = AprielMLP(config, **factory_kwargs) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) def forward( self, hidden_states: torch.Tensor, inference_params=None, **kwargs @@ -598,16 +602,16 @@ class AprielSSMModel(AprielSSMPreTrainedModel): config: AprielSSMConfig """ - def __init__(self, config: AprielSSMConfig): - super().__init__(config) + def __init__(self, config: AprielSSMConfig, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + factory_kwargs = {"device": device, "dtype": dtype} + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) self.layers = nn.ModuleList( - [AprielDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [AprielDecoderLayer(config, layer_idx, **factory_kwargs) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -664,11 +668,12 @@ class AprielSSMForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - def __init__(self, config): - super().__init__(config) + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) self.model = AprielSSMModel(config) self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + factory_kwargs = {"device": device, "dtype": dtype} + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) # Initialize weights and apply final processing self.post_init() @@ -722,6 +727,12 @@ def forward( last_hidden_state=outputs["last_hidden_state"], ) + def generate(self, *args, **kwargs): + """ + This is a wrapper to make sure we comply with the HF generation interface for eval harness + """ + return super().generate(*args, **kwargs) + __all__ = [ "AprielSSMForCausalLM", From a50bc2e1e414b89718f72086056bdbd1cbc01106 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 1 May 2025 20:38:37 +0000 Subject: [PATCH 050/122] cleanup --- fast_llm/layers/ssm/mamba2.py | 354 --- .../models/ssm/external/ariel_to_ssm.ipynb | 2612 ----------------- .../models/ssm/external/discrete_mamba2.py | 382 --- .../ssm/external/eval/apriel_eval_wrapper.py | 59 - .../models/ssm/external/eval/run_lm_eval.py | 6 - 5 files changed, 3413 deletions(-) delete mode 100644 fast_llm/layers/ssm/mamba2.py delete mode 100644 fast_llm/models/ssm/external/ariel_to_ssm.ipynb delete mode 100644 fast_llm/models/ssm/external/discrete_mamba2.py delete mode 100644 fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py delete mode 100644 fast_llm/models/ssm/external/eval/run_lm_eval.py diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py deleted file mode 100644 index 5763cb92c..000000000 --- a/fast_llm/layers/ssm/mamba2.py +++ /dev/null @@ -1,354 +0,0 @@ -""" -This code is adapted from https://github.com/jxiw/MambaInLlama/blob/main/mamba2/hybrid_mamba_layer.py -""" - -import math - -import causal_conv1d -import einops -import mamba_ssm.ops.triton.ssd_combined -import torch -from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated - -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import kaiming_init_ - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Mamba2(torch.nn.Module): - def __init__( - self, - config: SSMConfig, - layer_idx: int, - tensor_space: TensorSpace, - ): - # factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.config: SSMConfig = config - bias = config.add_bias_linear - self.layer_idx = layer_idx - - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - tensor_space.get_tensor_dim(SSMDimNames.v_heads) - tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - tensor_space.get_tensor_dim(SSMDimNames.inner_proj_mamba2) - - # self.d_model = d_model - # self.d_state = d_state - # self.d_conv = d_conv - # self.conv_init = conv_init - # self.expand = expand - # self.process_group = process_group - # self.sequence_parallel = sequence_parallel - # self.world_size = 1 if process_group is None else process_group.size() - # self.local_rank = 0 if process_group is None else process_group.rank() - # self.d_inner = d_inner if d_inner is not None else (self.expand * self.d_model) // self.world_size - # # assert self.d_inner * self.world_size == self.expand * self.d_model - # self.headdim = headdim - # self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size - # assert ngroups % self.world_size == 0 - # self.ngroups = ngroups // self.world_size - # assert self.d_ssm % self.headdim == 0 - # self.nheads = self.d_ssm // self.headdim - # self.D_has_hdim = D_has_hdim - # self.rmsnorm = rmsnorm - # self.norm_before_gate = norm_before_gate - # self.dt_limit = dt_limit - # self.activation = "silu" - # self.chunk_size = chunk_size - # self.use_mem_eff_path = use_mem_eff_path - # self.layer_idx = layer_idx - # self.d_xb = d_xb - # self.repeat_group = self.d_inner // self.d_xb - # self.repeat_kv_before_conv = repeat_kv_before_conv - - assert self.d_inner == self.ngroups * self.d_state - assert self.d_inner == self.d_ssm - - self.nheads = self.ngroups - self.headdim = self.d_state - - # Order: [z, x, B, C, dt] - # [hidden_dim, hidden_dim, d_state] - d_in_proj = self.d_inner + self.d_xb + self.d_xb + self.d_inner + self.nheads - # d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads - if self.process_group is None: - self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs) - else: - self.in_proj = ColumnParallelLinear( - self.d_model, - d_in_proj * self.world_size, - bias=bias, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - **factory_kwargs, - ) - - # conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state - - if self.repeat_kv_before_conv: - conv_dim = self.d_inner + self.d_inner + self.d_inner - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - **factory_kwargs, - ) - else: - conv_dim = self.d_inner + self.d_xb + self.d_xb - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - **factory_kwargs, - ) - - if self.conv_init is not None: - nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) - - self.act = nn.SiLU() - - # Initialize log dt bias - dt = torch.exp( - torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - ) - dt = torch.clamp(dt, min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - self.dt_bias = nn.Parameter(inv_dt) - # Just to be explicit. Without this we already don't put wd on dt_bias because of the check - # name.endswith("bias") in param_grouping.py - self.dt_bias._no_weight_decay = True - - assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] - A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range) - A_log = torch.log(A).to(dtype=dtype) - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)) - self.D._no_weight_decay = True - - if self.rmsnorm: - assert RMSNormGated is not None - self.norm = RMSNormGated( - self.d_ssm, - eps=1e-5, - norm_before_gate=self.norm_before_gate, - group_size=self.d_ssm // ngroups, - **factory_kwargs, - ) - - # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - self.out_proj = Linear( - td_inner, - td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), - ) - - def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None): - """ - u: (batch, seqlen, hidden_dim) if seqlen=None. - If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we - split u during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). - Returns: same shape as u - """ - seqlen_og = seqlen - if seqlen is None: - batch, seqlen, dim = u.shape - else: - batch_seqlen, dim = u.shape - batch = batch_seqlen // seqlen - - conv_state, ssm_state = None, None - if inference_params is not None: - inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch - conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, _, _ = self.step(u, conv_state, ssm_state) - return out - - zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj) - if seqlen_og is not None: - zxbcdt = einops.rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen) - # If the model is loaded in fp16, without the .float() here, A might be -inf - A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state) - dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit) - - # [z, x, B, C, dt] - d_mlp = (zxbcdt.shape[-1] - 2 * self.d_inner - 2 * self.d_xb - self.nheads) // 2 - z0, x0, z, xBC, dt = torch.split( - zxbcdt, [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.d_xb, self.nheads], dim=-1 - ) - - if self.repeat_kv_before_conv: - x, B, C = torch.split(xBC, [self.d_xb, self.d_xb, self.ngroups * self.d_state], dim=-1) - # minic the GQA - x = einops.rearrange(x, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - # x shape: (bsz, n_group, l, dim) - B = einops.rearrange(B, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) - B = repeat_kv(B, self.repeat_group) - # combine x, B, C - x = einops.rearrange(x, "b g l p -> b l (g p)") - B = einops.rearrange(B, "b g l p -> b l (g p)") - xBC = torch.cat((x, B, C), dim=-1) - - if conv_state is not None: - if cu_seqlens is None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = einops.rearrange(xBC, "b l d -> b d l") - conv_state.copy_( - torch.nn.functional.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0)) - ) # Update state (B D W) - else: - assert ( - causal_conv1d.causal_conv1d_varlen_states is not None - ), "varlen inference requires causal_conv1d package" - assert batch == 1, "varlen inference only supports batch dimension 1" - conv_varlen_states = causal_conv1d.causal_conv1d_varlen_states( - xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1] - ) - conv_state.copy_(conv_varlen_states) - assert self.activation in ["silu", "swish"] - - if causal_conv1d.causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: - assert seq_idx is None, "varlen conv1d requires the causal_conv1d package" - xBC = self.act( - self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.d_conv - 1) :] - ) # (B, L, self.d_ssm + 2 * ngroups * d_state) - else: - xBC = causal_conv1d.causal_conv1d_fn( - xBC.transpose(1, 2), - einops.rearrange(self.conv1d.weight, "d 1 w -> d w"), - bias=self.conv1d.bias, - activation=self.activation, - seq_idx=seq_idx, - ).transpose(1, 2) - - if self.repeat_kv_before_conv: - x, B, C = torch.split( - xBC, [self.ngroups * self.d_state, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1 - ) - - y = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( - einops.rearrange(x, "b l (h p) -> b l h p", p=self.headdim), - dt, - A, - einops.rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), - einops.rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), - chunk_size=self.chunk_size, - D=einops.rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, - z=einops.rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None, - dt_bias=self.dt_bias, - dt_softplus=True, - seq_idx=seq_idx, - cu_seqlens=cu_seqlens, - **dt_limit_kwargs, - return_final_states=ssm_state is not None, - return_varlen_states=cu_seqlens is not None and inference_params is not None, - ) - - else: - # self.d_xb + self.d_xb + self.d_inner - x, B, C = torch.split(xBC, [self.d_xb, self.d_xb, self.ngroups * self.d_state], dim=-1) - - # minic the GQA - x = einops.rearrange(x, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - # x shape: (bsz, n_group, l, dim) - - B = einops.rearrange(B, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) - B = repeat_kv(B, self.repeat_group) - - y = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( - # einops.rearrange(x, "b l (h p) -> b l h p", p=self.headdim), - einops.rearrange(x, "b g l p -> b l g p"), - dt, - A, - # einops.rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), - einops.rearrange(B, "b g l n -> b l g n"), - einops.rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), - chunk_size=self.chunk_size, - D=einops.rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, - z=einops.rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None, - dt_bias=self.dt_bias, - dt_softplus=True, - seq_idx=seq_idx, - cu_seqlens=cu_seqlens, - **dt_limit_kwargs, - return_final_states=ssm_state is not None, - return_varlen_states=cu_seqlens is not None and inference_params is not None, - ) - - if ssm_state is not None: - y, last_state, *rest = y - if cu_seqlens is None: - ssm_state.copy_(last_state) - else: - varlen_states = rest[0] - ssm_state.copy_(varlen_states) - y = einops.rearrange(y, "b l h p -> b l (h p)") - if self.rmsnorm: - y = self.norm(y, z) - if d_mlp > 0: - y = torch.cat([torch.nn.functional.silu(z0) * x0, y], dim=-1) - if seqlen_og is not None: - y = einops.rearrange(y, "b l d -> (b l) d") - out = self.out_proj(y) - return out - - assert self.layer_idx is not None - if self.layer_idx not in inference_params.key_value_memory_dict: - (batch_size,) - conv_state = torch.zeros( - batch_size, - self.d_conv, - self.conv1d.weight.shape[0], - device=self.conv1d.weight.device, - dtype=self.conv1d.weight.dtype, - ).transpose(1, 2) - ssm_state = torch.zeros( - batch_size, - self.nheads, - self.headdim, - self.d_state, - device=self.in_proj.weight.device, - dtype=self.in_proj.weight.dtype, - ) - inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) - else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - # TODO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: - conv_state.zero_() - ssm_state.zero_() - return conv_state, ssm_state diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb deleted file mode 100644 index a8390fa3d..000000000 --- a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb +++ /dev/null @@ -1,2612 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "import torch\n", - "from mamba_ssm import MambaLMHeadModel\n", - "from mamba_ssm.models.config_mamba import MambaConfig\n", - "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", - "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", - "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", - "from transformers.cache_utils import StaticCache\n", - "from types import SimpleNamespace\n", - "\n", - "# make sure the code changes reflected without reload\n", - "%load_ext autoreload\n", - "%autoreload 2\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 8.90it/s]\n" - ] - }, - { - "data": { - "text/plain": [ - "AprielForCausalLM(\n", - " (model): AprielModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (rotary_emb): AprielRotaryEmbedding()\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", - "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", - "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", - "apriel_state_dict = apriel_model.state_dict()\n", - "apriel_model.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.bfloat16" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_model.config.torch_dtype" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "n_params = sum(p.numel() for p in apriel_model.parameters() if p.requires_grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4.83207168" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "n_params/1e9" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n" - ] - } - ], - "source": [ - "config_apriel = AprielSSMConfig.from_pretrained(\"/mnt/checkpoints_fml/pretrained_models/ssm/apriel_ssm_instruct_base\", trust_remote_code=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", - "You are using a model of type llamba to instantiate a model of type apriel_ssm. This is not supported for all configurations of models and can yield errors.\n" - ] - }, - { - "ename": "KeyError", - "evalue": "'n_qk_heads'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m stage2_checkpoint \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m stage2_apriel_ssm \u001b[38;5;241m=\u001b[39m \u001b[43mAprielSSMForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstage2_checkpoint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbfloat16\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3571\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3569\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m 3570\u001b[0m config_path \u001b[38;5;241m=\u001b[39m config \u001b[38;5;28;01mif\u001b[39;00m config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> 3571\u001b[0m config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3572\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3573\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3574\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_unused_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 3575\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3576\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3577\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3578\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3579\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3580\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3581\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3582\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_auto\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_auto_class\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3583\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_pipeline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_pipeline\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3584\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3585\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3587\u001b[0m \u001b[38;5;66;03m# In case one passes a config to `from_pretrained` + \"attn_implementation\"\u001b[39;00m\n\u001b[1;32m 3588\u001b[0m \u001b[38;5;66;03m# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3592\u001b[0m \u001b[38;5;66;03m# we pop attn_implementation from the kwargs but this handles the case where users\u001b[39;00m\n\u001b[1;32m 3593\u001b[0m \u001b[38;5;66;03m# passes manually the config to `from_pretrained`.\u001b[39;00m\n\u001b[1;32m 3594\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config)\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:569\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type:\n\u001b[1;32m 564\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 565\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are using a model of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to instantiate a model of type \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 566\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 567\u001b[0m )\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:740\u001b[0m, in \u001b[0;36mPretrainedConfig.from_dict\u001b[0;34m(cls, config_dict, **kwargs)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;66;03m# We remove it from kwargs so that it does not appear in `return_unused_kwargs`.\u001b[39;00m\n\u001b[1;32m 738\u001b[0m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 740\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(config, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpruned_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 743\u001b[0m config\u001b[38;5;241m.\u001b[39mpruned_heads \u001b[38;5;241m=\u001b[39m {\u001b[38;5;28mint\u001b[39m(key): value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m config\u001b[38;5;241m.\u001b[39mpruned_heads\u001b[38;5;241m.\u001b[39mitems()}\n", - "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/configuration_ssm_apriel.py:99\u001b[0m, in \u001b[0;36mAprielSSMConfig.__init__\u001b[0;34m(self, vocab_size, hidden_size, intermediate_size, num_hidden_layers, hidden_act, initializer_range, use_cache, pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings, mlp_bias, rms_norm_eps, ssm_cfg, head_dim, **kwargs)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 82\u001b[0m pad_token_id\u001b[38;5;241m=\u001b[39mpad_token_id,\n\u001b[1;32m 83\u001b[0m bos_token_id\u001b[38;5;241m=\u001b[39mbos_token_id,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 87\u001b[0m )\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg \u001b[38;5;241m=\u001b[39m ssm_cfg \u001b[38;5;129;01mor\u001b[39;00m {\n\u001b[1;32m 90\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_state\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 91\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_v_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim, \u001b[38;5;66;03m# num_heads * head_dim\u001b[39;00m\n\u001b[1;32m 98\u001b[0m }\n\u001b[0;32m---> 99\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mssm_cfg\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mn_qk_heads\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n", - "\u001b[0;31mKeyError\u001b[0m: 'n_qk_heads'" - ] - } - ], - "source": [ - "stage2_checkpoint = \"/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\"\n", - "stage2_apriel_ssm = AprielSSMForCausalLM.from_pretrained(stage2_checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "apriel_ssm_config = AprielSSMConfig(vocab_size=config.vocab_size, \n", - " hidden_size=config.hidden_size,\n", - " intermediate_size=config.intermediate_size,\n", - " num_hidden_layers=config.num_hidden_layers,\n", - " hidden_act=config.hidden_act,\n", - " initializer_range=config.initializer_range,\n", - " use_cache=config.use_cache,\n", - " mlp_bias=config.mlp_bias,\n", - " tie_word_embeddings=config.tie_word_embeddings,\n", - " pad_token_id=config.pad_token_id,\n", - " bos_token_id=config.bos_token_id,\n", - " eos_token_id=config.eos_token_id,\n", - " head_dim=config.head_dim,\n", - " rms_norm_eps=config.rms_norm_eps)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "apriel_ssm = AprielSSMForCausalLM(apriel_ssm_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "OrderedDict([('model.embed_tokens.weight',\n", - " tensor([[ 0.0105, 0.0330, -0.0032, ..., 0.0076, -0.0051, 0.0112],\n", - " [-0.0111, -0.0101, 0.0064, ..., 0.0144, 0.0098, -0.0194],\n", - " [ 0.0301, 0.0228, 0.0105, ..., -0.0159, 0.0112, -0.0009],\n", - " ...,\n", - " [ 0.0266, 0.0224, -0.0150, ..., 0.0189, -0.0253, -0.0300],\n", - " [-0.0304, 0.0249, 0.0140, ..., -0.0235, 0.0315, -0.0188],\n", - " [-0.0215, -0.0034, 0.0035, ..., -0.0125, 0.0084, 0.0246]])),\n", - " ('model.layers.0.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.0.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.0.mixer.in_proj.weight',\n", - " tensor([[ 0.0104, 0.0055, -0.0148, ..., 0.0208, -0.0074, 0.0015],\n", - " [ 0.0102, 0.0148, 0.0148, ..., -0.0041, 0.0224, -0.0336],\n", - " [ 0.0129, -0.0179, -0.0120, ..., 0.0175, 0.0300, -0.0234],\n", - " ...,\n", - " [-0.0215, 0.0002, 0.0093, ..., -0.0424, 0.0016, -0.0162],\n", - " [-0.0178, -0.0093, 0.0226, ..., 0.0005, 0.0062, 0.0150],\n", - " [-0.0204, 0.0039, -0.0364, ..., -0.0128, 0.0002, 0.0134]])),\n", - " ('model.layers.0.mixer.conv1d.weight',\n", - " tensor([[[-0.1064, -0.3782, -0.3080, -0.3179]],\n", - " \n", - " [[-0.3493, 0.2230, 0.1062, 0.0614]],\n", - " \n", - " [[-0.4650, 0.0300, 0.3021, 0.1197]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.3686, 0.0679, 0.1440, 0.4445]],\n", - " \n", - " [[-0.1480, 0.3750, -0.0552, -0.0297]],\n", - " \n", - " [[ 0.0677, 0.0925, -0.0268, -0.0232]]])),\n", - " ('model.layers.0.mixer.conv1d.bias',\n", - " tensor([ 0.1379, 0.0862, -0.0723, ..., -0.2628, -0.1867, -0.1233])),\n", - " ('model.layers.0.mixer.out_proj.weight',\n", - " tensor([[ 0.0208, -0.0106, -0.0016, ..., 0.0117, 0.0140, -0.0040],\n", - " [-0.0147, 0.0419, 0.0327, ..., -0.0073, -0.0127, 0.0190],\n", - " [-0.0218, 0.0030, 0.0115, ..., -0.0062, 0.0214, 0.0105],\n", - " ...,\n", - " [ 0.0089, 0.0154, -0.0178, ..., -0.0206, -0.0378, 0.0102],\n", - " [ 0.0153, -0.0249, 0.0219, ..., 0.0119, 0.0019, 0.0383],\n", - " [-0.0126, 0.0284, -0.0035, ..., 0.0118, -0.0186, -0.0232]])),\n", - " ('model.layers.0.mlp.gate_proj.weight',\n", - " tensor([[-0.0032, -0.0405, 0.0180, ..., -0.0030, -0.0222, 0.0069],\n", - " [-0.0071, -0.0064, -0.0207, ..., 0.0037, -0.0077, 0.0261],\n", - " [ 0.0236, 0.0167, 0.0065, ..., 0.0064, 0.0035, -0.0092],\n", - " ...,\n", - " [-0.0357, 0.0192, 0.0099, ..., -0.0067, -0.0181, 0.0082],\n", - " [-0.0139, -0.0161, -0.0015, ..., -0.0052, -0.0337, 0.0514],\n", - " [ 0.0105, -0.0205, 0.0198, ..., 0.0090, 0.0315, 0.0066]])),\n", - " ('model.layers.0.mlp.up_proj.weight',\n", - " tensor([[ 0.0074, 0.0237, -0.0300, ..., 0.0343, 0.0016, 0.0395],\n", - " [ 0.0270, 0.0085, 0.0193, ..., 0.0199, -0.0139, 0.0094],\n", - " [ 0.0036, 0.0073, 0.0149, ..., 0.0094, 0.0346, -0.0111],\n", - " ...,\n", - " [ 0.0159, -0.0346, -0.0128, ..., 0.0377, -0.0531, -0.0305],\n", - " [ 0.0283, 0.0162, -0.0377, ..., -0.0254, 0.0110, -0.0167],\n", - " [-0.0277, 0.0130, 0.0161, ..., 0.0089, -0.0190, 0.0214]])),\n", - " ('model.layers.0.mlp.down_proj.weight',\n", - " tensor([[ 0.0157, 0.0105, 0.0036, ..., 0.0229, 0.0080, 0.0303],\n", - " [-0.0143, -0.0067, 0.0016, ..., 0.0494, -0.0043, 0.0072],\n", - " [-0.0148, 0.0113, 0.0025, ..., -0.0186, 0.0206, -0.0119],\n", - " ...,\n", - " [-0.0226, 0.0099, 0.0010, ..., 0.0123, -0.0170, 0.0024],\n", - " [-0.0120, -0.0015, -0.0355, ..., 0.0064, 0.0175, -0.0065],\n", - " [ 0.0364, 0.0364, 0.0265, ..., -0.0222, 0.0030, 0.0296]])),\n", - " ('model.layers.0.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.0.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.1.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.1.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.1.mixer.in_proj.weight',\n", - " tensor([[-0.0116, -0.0182, -0.0017, ..., -0.0216, -0.0136, -0.0203],\n", - " [-0.0142, -0.0106, -0.0334, ..., 0.0287, -0.0273, 0.0050],\n", - " [ 0.0131, -0.0106, -0.0012, ..., 0.0261, -0.0228, -0.0026],\n", - " ...,\n", - " [-0.0029, 0.0023, 0.0360, ..., -0.0195, 0.0018, -0.0227],\n", - " [ 0.0004, 0.0015, -0.0051, ..., -0.0095, 0.0269, 0.0179],\n", - " [ 0.0295, -0.0520, 0.0009, ..., 0.0019, 0.0255, 0.0478]])),\n", - " ('model.layers.1.mixer.conv1d.weight',\n", - " tensor([[[-0.4725, -0.2938, -0.3816, -0.1239]],\n", - " \n", - " [[-0.2002, 0.3790, 0.1908, -0.4679]],\n", - " \n", - " [[-0.3674, 0.3774, -0.2479, 0.4324]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.4181, 0.2263, -0.1937, 0.3585]],\n", - " \n", - " [[ 0.0704, 0.0913, 0.4217, 0.3004]],\n", - " \n", - " [[ 0.3175, -0.3239, -0.0614, -0.3978]]])),\n", - " ('model.layers.1.mixer.conv1d.bias',\n", - " tensor([ 0.4302, 0.0269, -0.3462, ..., 0.4887, 0.2848, 0.0745])),\n", - " ('model.layers.1.mixer.out_proj.weight',\n", - " tensor([[-0.0069, 0.0233, 0.0133, ..., -0.0064, -0.0085, 0.0166],\n", - " [-0.0302, 0.0129, -0.0042, ..., 0.0109, 0.0009, -0.0087],\n", - " [-0.0373, -0.0233, -0.0043, ..., -0.0017, 0.0384, -0.0114],\n", - " ...,\n", - " [-0.0219, 0.0330, -0.0341, ..., 0.0080, 0.0089, 0.0268],\n", - " [-0.0019, -0.0069, 0.0276, ..., 0.0182, -0.0240, 0.0163],\n", - " [ 0.0081, 0.0070, 0.0156, ..., -0.0135, 0.0469, -0.0221]])),\n", - " ('model.layers.1.mlp.gate_proj.weight',\n", - " tensor([[ 0.0175, -0.0074, -0.0028, ..., 0.0197, 0.0034, 0.0221],\n", - " [ 0.0063, 0.0339, -0.0047, ..., 0.0037, -0.0126, -0.0342],\n", - " [-0.0093, -0.0148, -0.0236, ..., 0.0190, -0.0451, -0.0173],\n", - " ...,\n", - " [ 0.0167, 0.0161, 0.0019, ..., -0.0083, -0.0133, 0.0141],\n", - " [-0.0163, 0.0383, -0.0203, ..., 0.0336, -0.0148, 0.0013],\n", - " [-0.0138, -0.0275, -0.0268, ..., -0.0243, -0.0031, -0.0227]])),\n", - " ('model.layers.1.mlp.up_proj.weight',\n", - " tensor([[ 0.0054, 0.0031, 0.0256, ..., 0.0002, 0.0020, -0.0050],\n", - " [ 0.0247, -0.0298, -0.0218, ..., -0.0161, 0.0253, 0.0128],\n", - " [-0.0231, -0.0012, 0.0130, ..., 0.0031, -0.0324, 0.0107],\n", - " ...,\n", - " [ 0.0359, -0.0202, 0.0386, ..., -0.0104, 0.0274, 0.0161],\n", - " [ 0.0062, -0.0111, 0.0338, ..., 0.0041, 0.0001, -0.0019],\n", - " [ 0.0105, -0.0258, 0.0184, ..., -0.0270, -0.0138, -0.0367]])),\n", - " ('model.layers.1.mlp.down_proj.weight',\n", - " tensor([[-0.0163, -0.0308, -0.0203, ..., 0.0002, -0.0227, 0.0019],\n", - " [ 0.0206, 0.0037, 0.0064, ..., -0.0261, -0.0206, 0.0063],\n", - " [ 0.0044, -0.0073, -0.0576, ..., -0.0015, -0.0082, 0.0022],\n", - " ...,\n", - " [-0.0034, 0.0142, -0.0547, ..., -0.0106, -0.0090, 0.0249],\n", - " [-0.0068, 0.0127, -0.0066, ..., -0.0255, 0.0004, 0.0106],\n", - " [-0.0293, 0.0146, -0.0142, ..., -0.0073, -0.0284, -0.0069]])),\n", - " ('model.layers.1.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.1.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.2.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.2.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.2.mixer.in_proj.weight',\n", - " tensor([[ 0.0337, -0.0055, -0.0538, ..., -0.0051, 0.0107, -0.0338],\n", - " [ 0.0227, -0.0008, 0.0003, ..., -0.0312, 0.0090, -0.0126],\n", - " [-0.0238, 0.0146, 0.0240, ..., -0.0114, -0.0180, 0.0025],\n", - " ...,\n", - " [-0.0208, -0.0261, 0.0227, ..., 0.0071, 0.0014, 0.0237],\n", - " [ 0.0356, 0.0372, 0.0186, ..., 0.0052, 0.0049, -0.0195],\n", - " [ 0.0023, -0.0159, -0.0238, ..., 0.0194, -0.0056, -0.0275]])),\n", - " ('model.layers.2.mixer.conv1d.weight',\n", - " tensor([[[ 0.1054, -0.4185, 0.4229, 0.3289]],\n", - " \n", - " [[-0.0081, 0.0321, 0.1334, -0.1055]],\n", - " \n", - " [[ 0.1587, -0.3806, -0.1336, -0.2662]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.2830, -0.3875, -0.2972, 0.0030]],\n", - " \n", - " [[ 0.4210, 0.2190, -0.4942, 0.0465]],\n", - " \n", - " [[-0.1830, -0.3686, 0.2928, -0.0313]]])),\n", - " ('model.layers.2.mixer.conv1d.bias',\n", - " tensor([-0.2931, -0.3513, -0.3013, ..., -0.1934, -0.3115, 0.3889])),\n", - " ('model.layers.2.mixer.out_proj.weight',\n", - " tensor([[-0.0038, -0.0160, -0.0042, ..., 0.0062, 0.0059, -0.0126],\n", - " [-0.0027, -0.0012, -0.0065, ..., -0.0032, 0.0129, -0.0298],\n", - " [ 0.0394, -0.0096, 0.0107, ..., -0.0290, 0.0248, 0.0308],\n", - " ...,\n", - " [ 0.0087, 0.0067, -0.0261, ..., -0.0038, -0.0168, 0.0485],\n", - " [ 0.0118, 0.0042, -0.0186, ..., 0.0104, 0.0281, 0.0028],\n", - " [ 0.0304, -0.0382, -0.0028, ..., -0.0264, -0.0050, 0.0050]])),\n", - " ('model.layers.2.mlp.gate_proj.weight',\n", - " tensor([[-0.0169, 0.0036, 0.0024, ..., 0.0429, 0.0313, 0.0167],\n", - " [-0.0100, 0.0011, -0.0024, ..., -0.0065, 0.0090, 0.0123],\n", - " [ 0.0102, 0.0282, 0.0166, ..., -0.0082, 0.0123, 0.0253],\n", - " ...,\n", - " [ 0.0168, -0.0056, -0.0096, ..., -0.0090, 0.0150, 0.0209],\n", - " [ 0.0258, 0.0113, -0.0093, ..., 0.0335, 0.0386, -0.0156],\n", - " [ 0.0129, 0.0338, -0.0006, ..., -0.0346, 0.0135, -0.0213]])),\n", - " ('model.layers.2.mlp.up_proj.weight',\n", - " tensor([[-0.0029, 0.0416, -0.0102, ..., -0.0413, 0.0019, 0.0063],\n", - " [ 0.0054, 0.0138, 0.0031, ..., -0.0077, -0.0070, -0.0016],\n", - " [ 0.0128, 0.0153, -0.0147, ..., -0.0131, -0.0244, 0.0097],\n", - " ...,\n", - " [-0.0190, -0.0025, 0.0322, ..., -0.0106, -0.0323, -0.0144],\n", - " [-0.0269, -0.0007, 0.0070, ..., 0.0191, -0.0025, 0.0033],\n", - " [-0.0311, 0.0217, -0.0021, ..., 0.0302, -0.0131, 0.0388]])),\n", - " ('model.layers.2.mlp.down_proj.weight',\n", - " tensor([[ 0.0150, -0.0127, 0.0372, ..., 0.0018, 0.0018, 0.0187],\n", - " [-0.0262, 0.0164, 0.0281, ..., 0.0120, -0.0187, -0.0177],\n", - " [ 0.0129, -0.0042, 0.0018, ..., -0.0136, 0.0278, 0.0284],\n", - " ...,\n", - " [ 0.0048, 0.0421, -0.0018, ..., 0.0002, -0.0064, 0.0085],\n", - " [ 0.0276, 0.0146, 0.0228, ..., 0.0055, -0.0288, -0.0081],\n", - " [-0.0133, 0.0102, 0.0318, ..., 0.0209, -0.0270, 0.0128]])),\n", - " ('model.layers.2.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.2.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.3.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.3.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.3.mixer.in_proj.weight',\n", - " tensor([[ 7.4766e-03, -9.8698e-03, -1.9172e-02, ..., 3.7842e-02,\n", - " -2.1648e-03, 2.8147e-03],\n", - " [ 2.4954e-02, -1.2659e-02, 8.0447e-04, ..., 3.1716e-02,\n", - " 4.9989e-03, 6.4200e-03],\n", - " [-3.3345e-02, -1.5256e-02, 2.7295e-02, ..., -1.1240e-02,\n", - " 9.7000e-03, 3.1136e-05],\n", - " ...,\n", - " [-2.0807e-04, -2.5132e-02, -1.9983e-02, ..., -2.9541e-02,\n", - " 4.6152e-04, 5.5341e-02],\n", - " [ 2.0498e-03, 2.2021e-02, -7.6882e-03, ..., 1.6469e-02,\n", - " -1.0645e-02, -1.8442e-03],\n", - " [ 2.0949e-03, -1.2398e-02, 1.2922e-02, ..., 1.1862e-02,\n", - " -4.7119e-03, 3.2352e-02]])),\n", - " ('model.layers.3.mixer.conv1d.weight',\n", - " tensor([[[ 0.2590, 0.1670, 0.3987, -0.1694]],\n", - " \n", - " [[-0.4425, 0.1468, 0.3060, -0.0764]],\n", - " \n", - " [[-0.3638, -0.0575, 0.2156, -0.2468]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0111, -0.0182, -0.3816, 0.0382]],\n", - " \n", - " [[-0.4723, -0.3712, 0.1963, 0.2877]],\n", - " \n", - " [[-0.4890, 0.1197, 0.1361, 0.3282]]])),\n", - " ('model.layers.3.mixer.conv1d.bias',\n", - " tensor([-0.4712, -0.3272, 0.4587, ..., -0.3145, 0.4086, 0.4005])),\n", - " ('model.layers.3.mixer.out_proj.weight',\n", - " tensor([[-0.0362, 0.0137, -0.0296, ..., -0.0028, 0.0104, 0.0393],\n", - " [ 0.0130, 0.0246, -0.0132, ..., 0.0082, -0.0044, -0.0054],\n", - " [-0.0081, -0.0115, -0.0064, ..., 0.0250, -0.0076, -0.0021],\n", - " ...,\n", - " [ 0.0230, -0.0055, 0.0056, ..., 0.0076, 0.0016, -0.0068],\n", - " [ 0.0472, -0.0068, 0.0336, ..., 0.0079, 0.0211, 0.0031],\n", - " [-0.0450, -0.0005, 0.0219, ..., 0.0044, -0.0006, -0.0278]])),\n", - " ('model.layers.3.mlp.gate_proj.weight',\n", - " tensor([[ 0.0034, 0.0445, -0.0132, ..., 0.0290, 0.0019, 0.0048],\n", - " [ 0.0271, 0.0109, 0.0028, ..., -0.0304, -0.0237, -0.0017],\n", - " [ 0.0098, 0.0252, 0.0392, ..., 0.0486, 0.0326, -0.0171],\n", - " ...,\n", - " [-0.0015, 0.0080, 0.0005, ..., -0.0158, -0.0067, 0.0347],\n", - " [-0.0638, 0.0120, 0.0076, ..., 0.0007, 0.0052, -0.0109],\n", - " [-0.0303, -0.0168, -0.0537, ..., -0.0163, -0.0030, -0.0068]])),\n", - " ('model.layers.3.mlp.up_proj.weight',\n", - " tensor([[-0.0074, -0.0101, 0.0073, ..., -0.0012, -0.0208, -0.0239],\n", - " [ 0.0035, 0.0010, 0.0157, ..., -0.0228, -0.0224, 0.0194],\n", - " [ 0.0457, -0.0129, -0.0063, ..., -0.0312, 0.0261, -0.0018],\n", - " ...,\n", - " [ 0.0012, 0.0093, 0.0121, ..., -0.0035, -0.0367, -0.0454],\n", - " [ 0.0308, -0.0334, 0.0062, ..., 0.0043, -0.0031, -0.0406],\n", - " [-0.0175, -0.0089, -0.0137, ..., -0.0322, -0.0070, -0.0219]])),\n", - " ('model.layers.3.mlp.down_proj.weight',\n", - " tensor([[ 0.0226, 0.0074, -0.0170, ..., 0.0035, 0.0420, -0.0085],\n", - " [ 0.0116, 0.0173, -0.0009, ..., -0.0302, 0.0075, 0.0153],\n", - " [-0.0092, 0.0119, 0.0164, ..., 0.0233, -0.0177, -0.0397],\n", - " ...,\n", - " [-0.0006, -0.0275, 0.0127, ..., -0.0185, 0.0335, -0.0133],\n", - " [ 0.0064, -0.0200, 0.0296, ..., 0.0041, -0.0114, -0.0221],\n", - " [ 0.0317, 0.0392, 0.0553, ..., 0.0191, 0.0188, -0.0176]])),\n", - " ('model.layers.3.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.3.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.4.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.4.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.4.mixer.in_proj.weight',\n", - " tensor([[-0.0266, 0.0092, -0.0260, ..., -0.0121, -0.0286, 0.0267],\n", - " [ 0.0144, -0.0053, -0.0060, ..., -0.0065, 0.0201, -0.0025],\n", - " [-0.0092, -0.0465, -0.0032, ..., 0.0192, -0.0026, 0.0104],\n", - " ...,\n", - " [-0.0210, -0.0286, -0.0148, ..., 0.0593, 0.0130, 0.0118],\n", - " [ 0.0361, -0.0070, 0.0054, ..., -0.0073, 0.0004, 0.0287],\n", - " [ 0.0450, -0.0286, 0.0191, ..., -0.0180, 0.0039, -0.0033]])),\n", - " ('model.layers.4.mixer.conv1d.weight',\n", - " tensor([[[ 0.1450, 0.2065, -0.1750, -0.4560]],\n", - " \n", - " [[-0.2889, -0.4707, -0.0741, 0.1254]],\n", - " \n", - " [[-0.4665, 0.1876, -0.4049, 0.1143]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0709, 0.2021, -0.0053, -0.1558]],\n", - " \n", - " [[-0.0195, -0.4046, -0.2437, -0.4405]],\n", - " \n", - " [[-0.3615, -0.4314, 0.1667, 0.3139]]])),\n", - " ('model.layers.4.mixer.conv1d.bias',\n", - " tensor([-0.3220, -0.4181, -0.0623, ..., 0.2788, 0.0518, 0.4607])),\n", - " ('model.layers.4.mixer.out_proj.weight',\n", - " tensor([[-0.0011, -0.0279, -0.0160, ..., -0.0222, 0.0262, 0.0234],\n", - " [ 0.0024, 0.0178, -0.0142, ..., 0.0048, -0.0145, 0.0332],\n", - " [-0.0084, -0.0037, 0.0054, ..., -0.0201, -0.0341, -0.0053],\n", - " ...,\n", - " [-0.0120, -0.0440, 0.0097, ..., -0.0070, -0.0129, 0.0170],\n", - " [ 0.0096, -0.0034, -0.0025, ..., 0.0242, 0.0047, 0.0093],\n", - " [ 0.0254, 0.0207, 0.0135, ..., 0.0204, -0.0185, -0.0026]])),\n", - " ('model.layers.4.mlp.gate_proj.weight',\n", - " tensor([[ 0.0049, 0.0087, 0.0081, ..., 0.0145, 0.0188, 0.0441],\n", - " [-0.0103, 0.0147, 0.0180, ..., -0.0190, 0.0182, 0.0160],\n", - " [-0.0041, 0.0289, 0.0106, ..., 0.0144, -0.0070, 0.0104],\n", - " ...,\n", - " [ 0.0086, 0.0079, 0.0155, ..., 0.0037, -0.0242, 0.0091],\n", - " [-0.0320, 0.0084, -0.0508, ..., 0.0003, -0.0120, 0.0129],\n", - " [ 0.0079, 0.0185, 0.0285, ..., -0.0324, 0.0444, -0.0147]])),\n", - " ('model.layers.4.mlp.up_proj.weight',\n", - " tensor([[ 3.4382e-03, 1.9171e-02, 4.1226e-03, ..., 1.3158e-02,\n", - " 3.6365e-02, -8.1017e-03],\n", - " [ 1.8713e-02, -2.7732e-03, 3.1982e-02, ..., -8.5724e-03,\n", - " -3.1505e-02, 2.1047e-03],\n", - " [ 1.2329e-02, 1.8352e-03, 9.2540e-03, ..., 2.9880e-02,\n", - " -2.7856e-04, -8.7440e-04],\n", - " ...,\n", - " [-2.2330e-02, -2.0716e-02, 9.0004e-05, ..., -1.6298e-02,\n", - " -1.9620e-02, 2.5112e-02],\n", - " [ 7.1659e-03, 1.2942e-02, 1.0291e-03, ..., -1.0113e-02,\n", - " -1.6838e-03, 2.0189e-02],\n", - " [ 7.2108e-03, 3.1229e-02, 2.2533e-03, ..., -2.0148e-02,\n", - " -1.3502e-02, -1.8923e-02]])),\n", - " ('model.layers.4.mlp.down_proj.weight',\n", - " tensor([[ 0.0140, -0.0129, 0.0005, ..., -0.0068, -0.0335, 0.0172],\n", - " [-0.0175, -0.0011, 0.0114, ..., -0.0087, -0.0048, -0.0231],\n", - " [-0.0053, -0.0079, -0.0172, ..., -0.0125, -0.0200, 0.0127],\n", - " ...,\n", - " [ 0.0321, -0.0039, 0.0142, ..., 0.0384, 0.0054, 0.0321],\n", - " [ 0.0041, -0.0150, 0.0141, ..., 0.0049, -0.0348, -0.0028],\n", - " [ 0.0176, 0.0132, 0.0090, ..., -0.0117, 0.0241, 0.0417]])),\n", - " ('model.layers.4.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.4.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.5.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.5.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.5.mixer.in_proj.weight',\n", - " tensor([[ 0.0270, 0.0124, 0.0098, ..., 0.0170, -0.0225, 0.0032],\n", - " [ 0.0245, -0.0008, 0.0226, ..., 0.0219, -0.0219, 0.0087],\n", - " [-0.0175, 0.0181, 0.0124, ..., 0.0038, -0.0094, 0.0079],\n", - " ...,\n", - " [-0.0080, -0.0011, 0.0316, ..., -0.0012, 0.0254, 0.0251],\n", - " [-0.0141, -0.0159, -0.0069, ..., 0.0147, -0.0161, -0.0093],\n", - " [ 0.0252, 0.0125, 0.0174, ..., -0.0065, 0.0110, 0.0272]])),\n", - " ('model.layers.5.mixer.conv1d.weight',\n", - " tensor([[[ 0.0684, -0.4353, 0.3899, 0.3199]],\n", - " \n", - " [[ 0.4136, 0.4306, -0.4871, 0.4781]],\n", - " \n", - " [[-0.2516, 0.2109, 0.3891, 0.1501]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.0781, -0.0675, -0.2995, -0.1805]],\n", - " \n", - " [[-0.3360, -0.4148, 0.1846, -0.1013]],\n", - " \n", - " [[ 0.1725, 0.1929, -0.0337, 0.1375]]])),\n", - " ('model.layers.5.mixer.conv1d.bias',\n", - " tensor([-0.4975, -0.0629, -0.2420, ..., -0.2253, 0.2512, 0.2788])),\n", - " ('model.layers.5.mixer.out_proj.weight',\n", - " tensor([[ 1.4306e-02, 1.3230e-02, -2.4141e-02, ..., 1.1763e-02,\n", - " 7.0706e-03, -4.7970e-03],\n", - " [ 2.7478e-02, 1.5179e-03, 1.9229e-02, ..., 1.0928e-02,\n", - " 2.2802e-02, -2.9729e-03],\n", - " [ 1.0169e-02, -1.0741e-02, 2.0628e-02, ..., -1.8109e-02,\n", - " -4.2582e-03, 2.4007e-02],\n", - " ...,\n", - " [-3.2843e-03, 3.7835e-03, -6.7958e-03, ..., -2.6205e-02,\n", - " -2.0391e-02, 5.3912e-03],\n", - " [ 1.2515e-02, -6.4975e-03, 9.9616e-05, ..., 1.0444e-02,\n", - " -2.0596e-02, -8.2915e-03],\n", - " [ 1.7899e-02, 2.0418e-02, -1.9891e-02, ..., -6.6709e-03,\n", - " -3.8566e-02, 2.7005e-02]])),\n", - " ('model.layers.5.mlp.gate_proj.weight',\n", - " tensor([[-2.3807e-03, 2.2714e-03, 2.2736e-05, ..., -2.3039e-03,\n", - " 3.6159e-02, -1.7253e-02],\n", - " [ 3.6929e-02, -6.2031e-03, 1.3606e-02, ..., 2.3592e-02,\n", - " 4.4487e-03, -9.6723e-03],\n", - " [ 4.7507e-02, 2.6413e-02, 1.6759e-02, ..., 1.1910e-02,\n", - " 1.2872e-02, -1.0443e-02],\n", - " ...,\n", - " [-2.0354e-02, -3.9074e-03, 9.7952e-03, ..., 1.0730e-02,\n", - " 2.8752e-02, -8.0048e-03],\n", - " [ 2.5331e-02, -9.9732e-03, 1.0772e-02, ..., 2.0420e-02,\n", - " -3.2179e-02, -1.6437e-02],\n", - " [-3.4425e-02, -1.4578e-02, 2.9686e-03, ..., 4.5907e-02,\n", - " 7.7639e-03, -2.2494e-03]])),\n", - " ('model.layers.5.mlp.up_proj.weight',\n", - " tensor([[ 1.5868e-02, -1.9222e-02, -1.2880e-03, ..., 8.3353e-03,\n", - " -1.8538e-02, 6.7395e-03],\n", - " [-1.8051e-02, -5.0142e-02, -2.2177e-03, ..., -9.3852e-03,\n", - " -3.0374e-02, 2.5795e-02],\n", - " [-1.1737e-02, 2.6278e-02, -2.3205e-02, ..., -1.8399e-03,\n", - " 1.4115e-02, -2.6438e-02],\n", - " ...,\n", - " [ 2.7706e-02, -2.5067e-03, -8.7058e-03, ..., 2.1662e-03,\n", - " -4.9858e-02, -1.1575e-02],\n", - " [-9.5670e-04, 2.1698e-02, -5.4794e-03, ..., -1.0661e-02,\n", - " 1.8568e-02, 5.2615e-03],\n", - " [ 1.0739e-03, 2.2945e-02, 3.0835e-02, ..., 4.1212e-03,\n", - " 1.2643e-02, -1.1568e-05]])),\n", - " ('model.layers.5.mlp.down_proj.weight',\n", - " tensor([[ 0.0052, -0.0343, 0.0072, ..., 0.0004, 0.0320, 0.0362],\n", - " [ 0.0171, -0.0238, -0.0316, ..., 0.0231, 0.0377, 0.0141],\n", - " [-0.0205, 0.0152, 0.0002, ..., -0.0061, -0.0353, -0.0138],\n", - " ...,\n", - " [-0.0039, -0.0039, 0.0326, ..., -0.0208, 0.0160, 0.0185],\n", - " [ 0.0176, -0.0300, -0.0024, ..., -0.0292, -0.0254, -0.0366],\n", - " [ 0.0361, 0.0243, -0.0253, ..., -0.0036, -0.0099, -0.0133]])),\n", - " ('model.layers.5.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.5.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.6.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.6.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.6.mixer.in_proj.weight',\n", - " tensor([[-0.0505, -0.0650, 0.0059, ..., 0.0060, 0.0347, 0.0149],\n", - " [-0.0216, 0.0057, -0.0281, ..., -0.0162, 0.0081, 0.0016],\n", - " [-0.0339, -0.0314, 0.0253, ..., 0.0030, 0.0139, -0.0039],\n", - " ...,\n", - " [ 0.0355, -0.0238, -0.0015, ..., 0.0063, 0.0284, -0.0089],\n", - " [ 0.0093, -0.0381, -0.0261, ..., -0.0170, -0.0170, -0.0288],\n", - " [-0.0228, -0.0110, 0.0107, ..., 0.0300, 0.0010, 0.0141]])),\n", - " ('model.layers.6.mixer.conv1d.weight',\n", - " tensor([[[ 0.4364, 0.2888, 0.2343, 0.3226]],\n", - " \n", - " [[ 0.2804, 0.3558, 0.4061, -0.0480]],\n", - " \n", - " [[ 0.4964, 0.0709, 0.0748, 0.0971]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4291, 0.2445, -0.3121, 0.4013]],\n", - " \n", - " [[-0.1590, -0.1516, 0.0804, 0.2009]],\n", - " \n", - " [[ 0.1686, 0.0492, -0.2932, 0.1381]]])),\n", - " ('model.layers.6.mixer.conv1d.bias',\n", - " tensor([ 0.4241, -0.0500, 0.3393, ..., 0.1598, -0.4924, -0.3241])),\n", - " ('model.layers.6.mixer.out_proj.weight',\n", - " tensor([[ 0.0026, 0.0272, 0.0005, ..., 0.0434, -0.0293, -0.0105],\n", - " [ 0.0323, -0.0515, 0.0107, ..., -0.0406, 0.0252, -0.0038],\n", - " [-0.0156, -0.0078, 0.0173, ..., 0.0312, -0.0014, -0.0014],\n", - " ...,\n", - " [ 0.0014, -0.0522, -0.0154, ..., 0.0090, -0.0050, -0.0049],\n", - " [ 0.0350, 0.0099, -0.0014, ..., -0.0008, -0.0185, -0.0033],\n", - " [ 0.0134, 0.0002, 0.0325, ..., -0.0129, 0.0165, -0.0265]])),\n", - " ('model.layers.6.mlp.gate_proj.weight',\n", - " tensor([[-0.0011, 0.0202, 0.0236, ..., -0.0137, -0.0063, 0.0085],\n", - " [ 0.0163, 0.0261, 0.0120, ..., -0.0003, -0.0254, 0.0001],\n", - " [ 0.0318, -0.0121, 0.0103, ..., -0.0053, 0.0194, 0.0530],\n", - " ...,\n", - " [ 0.0039, 0.0228, -0.0147, ..., 0.0027, 0.0092, -0.0033],\n", - " [-0.0040, 0.0144, 0.0038, ..., -0.0106, -0.0022, 0.0094],\n", - " [ 0.0220, 0.0296, 0.0550, ..., 0.0079, -0.0135, -0.0092]])),\n", - " ('model.layers.6.mlp.up_proj.weight',\n", - " tensor([[ 0.0061, -0.0291, -0.0133, ..., 0.0054, -0.0049, -0.0028],\n", - " [-0.0032, -0.0201, 0.0218, ..., -0.0155, -0.0264, 0.0496],\n", - " [-0.0046, 0.0384, -0.0093, ..., 0.0356, -0.0245, 0.0175],\n", - " ...,\n", - " [-0.0111, -0.0092, -0.0143, ..., 0.0010, -0.0453, 0.0024],\n", - " [ 0.0078, -0.0025, 0.0227, ..., -0.0130, 0.0118, 0.0095],\n", - " [ 0.0234, -0.0114, -0.0102, ..., -0.0179, -0.0066, -0.0115]])),\n", - " ('model.layers.6.mlp.down_proj.weight',\n", - " tensor([[ 3.6976e-02, 1.7124e-02, -2.1290e-02, ..., -2.5206e-02,\n", - " 4.8023e-03, 9.8474e-03],\n", - " [-7.2866e-03, -5.4149e-03, -2.2242e-03, ..., -8.1606e-03,\n", - " -9.5275e-04, -1.8121e-02],\n", - " [-8.3493e-03, 1.2509e-02, 1.0773e-02, ..., 2.7061e-02,\n", - " 2.8131e-03, 5.8219e-03],\n", - " ...,\n", - " [ 8.7099e-03, 3.9196e-02, -3.5129e-03, ..., -2.3595e-02,\n", - " -8.3965e-03, 2.0074e-02],\n", - " [-2.7467e-02, -2.8721e-03, -2.2291e-02, ..., 9.7135e-03,\n", - " 3.4947e-02, -2.2158e-02],\n", - " [ 6.1744e-03, -4.7684e-03, 4.6690e-04, ..., -3.2948e-03,\n", - " 4.0735e-05, 3.3651e-02]])),\n", - " ('model.layers.6.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.6.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.7.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.7.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.7.mixer.in_proj.weight',\n", - " tensor([[-0.0045, -0.0288, 0.0362, ..., -0.0092, -0.0026, 0.0051],\n", - " [ 0.0160, 0.0139, 0.0057, ..., 0.0121, 0.0071, 0.0134],\n", - " [ 0.0062, 0.0181, 0.0161, ..., -0.0284, -0.0014, -0.0171],\n", - " ...,\n", - " [-0.0053, 0.0067, 0.0095, ..., -0.0175, 0.0235, 0.0125],\n", - " [-0.0048, 0.0041, 0.0038, ..., 0.0099, 0.0194, 0.0124],\n", - " [ 0.0131, 0.0073, -0.0284, ..., 0.0138, -0.0218, 0.0019]])),\n", - " ('model.layers.7.mixer.conv1d.weight',\n", - " tensor([[[ 0.2528, -0.0556, -0.3225, 0.1327]],\n", - " \n", - " [[-0.0437, 0.4941, -0.4075, 0.1062]],\n", - " \n", - " [[-0.3428, 0.2675, 0.1871, 0.0260]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.0409, -0.4458, 0.4488, 0.2841]],\n", - " \n", - " [[-0.2370, -0.3965, 0.0656, -0.1339]],\n", - " \n", - " [[ 0.4677, 0.0073, 0.3741, 0.1525]]])),\n", - " ('model.layers.7.mixer.conv1d.bias',\n", - " tensor([-0.1844, -0.1347, 0.0043, ..., -0.3839, -0.2167, -0.4637])),\n", - " ('model.layers.7.mixer.out_proj.weight',\n", - " tensor([[-2.8471e-02, 3.9783e-03, 6.0125e-03, ..., -1.6079e-02,\n", - " 1.4225e-02, 2.8166e-02],\n", - " [ 5.4680e-03, -5.1414e-03, 5.3077e-05, ..., 1.8734e-02,\n", - " 3.7454e-03, 1.7579e-02],\n", - " [-1.2955e-02, 1.4954e-02, 6.4922e-03, ..., -2.6830e-02,\n", - " 1.4766e-02, -1.8002e-02],\n", - " ...,\n", - " [ 1.7150e-02, 4.6781e-02, -1.1136e-02, ..., 4.7242e-03,\n", - " -1.3072e-02, -1.0412e-02],\n", - " [ 5.5498e-03, -3.0803e-02, -2.4880e-02, ..., -4.2644e-03,\n", - " -1.1047e-02, 1.5815e-02],\n", - " [ 1.7242e-02, 2.7994e-02, -4.8186e-04, ..., -2.2003e-02,\n", - " -2.1834e-02, -2.1826e-02]])),\n", - " ('model.layers.7.mlp.gate_proj.weight',\n", - " tensor([[-0.0302, -0.0160, -0.0341, ..., -0.0121, 0.0007, -0.0338],\n", - " [-0.0186, 0.0257, -0.0154, ..., 0.0153, -0.0029, 0.0163],\n", - " [ 0.0170, 0.0223, -0.0185, ..., -0.0020, 0.0061, 0.0174],\n", - " ...,\n", - " [-0.0044, 0.0044, 0.0077, ..., -0.0183, 0.0041, -0.0003],\n", - " [ 0.0168, 0.0149, -0.0221, ..., 0.0112, 0.0357, 0.0042],\n", - " [ 0.0310, -0.0217, 0.0070, ..., -0.0394, -0.0065, 0.0204]])),\n", - " ('model.layers.7.mlp.up_proj.weight',\n", - " tensor([[-0.0031, -0.0110, 0.0091, ..., 0.0152, -0.0013, 0.0096],\n", - " [ 0.0013, 0.0354, -0.0037, ..., 0.0130, 0.0204, 0.0262],\n", - " [-0.0075, -0.0044, 0.0207, ..., 0.0057, 0.0115, 0.0151],\n", - " ...,\n", - " [-0.0015, 0.0095, -0.0100, ..., -0.0150, 0.0105, -0.0350],\n", - " [-0.0300, -0.0092, -0.0176, ..., -0.0113, 0.0164, -0.0117],\n", - " [-0.0291, -0.0085, 0.0058, ..., 0.0386, -0.0174, -0.0092]])),\n", - " ('model.layers.7.mlp.down_proj.weight',\n", - " tensor([[-0.0276, 0.0017, -0.0217, ..., 0.0302, -0.0079, -0.0003],\n", - " [ 0.0379, 0.0052, 0.0052, ..., 0.0145, 0.0139, -0.0143],\n", - " [ 0.0176, -0.0028, 0.0172, ..., -0.0205, -0.0165, -0.0040],\n", - " ...,\n", - " [ 0.0095, -0.0139, 0.0077, ..., -0.0080, 0.0339, 0.0172],\n", - " [-0.0177, 0.0009, -0.0245, ..., 0.0040, 0.0258, 0.0202],\n", - " [-0.0064, -0.0270, 0.0041, ..., -0.0133, -0.0040, 0.0038]])),\n", - " ('model.layers.7.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.7.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.8.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.8.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.8.mixer.in_proj.weight',\n", - " tensor([[ 0.0050, 0.0270, -0.0196, ..., -0.0121, -0.0090, 0.0083],\n", - " [-0.0083, -0.0177, 0.0159, ..., 0.0298, -0.0202, -0.0265],\n", - " [ 0.0058, 0.0186, 0.0125, ..., -0.0067, -0.0255, 0.0298],\n", - " ...,\n", - " [-0.0164, 0.0012, 0.0023, ..., -0.0355, 0.0347, -0.0011],\n", - " [-0.0371, 0.0033, 0.0345, ..., -0.0097, 0.0019, 0.0185],\n", - " [-0.0322, -0.0160, 0.0072, ..., -0.0195, -0.0229, 0.0118]])),\n", - " ('model.layers.8.mixer.conv1d.weight',\n", - " tensor([[[-0.0520, 0.3004, -0.1990, 0.2512]],\n", - " \n", - " [[-0.4120, -0.0055, 0.1484, -0.3316]],\n", - " \n", - " [[ 0.3939, -0.0567, 0.1432, 0.1880]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.2849, 0.2494, -0.2141, -0.3375]],\n", - " \n", - " [[-0.2823, -0.2402, 0.2228, 0.2331]],\n", - " \n", - " [[ 0.1914, 0.4269, 0.1228, -0.3408]]])),\n", - " ('model.layers.8.mixer.conv1d.bias',\n", - " tensor([0.1304, 0.2065, 0.3084, ..., 0.3863, 0.4883, 0.4724])),\n", - " ('model.layers.8.mixer.out_proj.weight',\n", - " tensor([[ 0.0008, -0.0019, 0.0084, ..., -0.0003, 0.0045, 0.0024],\n", - " [ 0.0137, -0.0003, -0.0031, ..., 0.0013, 0.0131, 0.0090],\n", - " [ 0.0095, 0.0488, -0.0355, ..., 0.0344, -0.0229, -0.0150],\n", - " ...,\n", - " [ 0.0029, 0.0164, -0.0380, ..., -0.0005, -0.0031, 0.0127],\n", - " [-0.0039, 0.0283, 0.0295, ..., 0.0271, -0.0105, -0.0158],\n", - " [-0.0057, -0.0178, 0.0129, ..., 0.0323, -0.0091, 0.0178]])),\n", - " ('model.layers.8.mlp.gate_proj.weight',\n", - " tensor([[-0.0047, 0.0037, -0.0129, ..., 0.0255, -0.0118, 0.0084],\n", - " [ 0.0418, -0.0020, 0.0205, ..., 0.0161, 0.0306, 0.0250],\n", - " [ 0.0011, 0.0144, 0.0204, ..., -0.0007, 0.0298, -0.0067],\n", - " ...,\n", - " [-0.0536, -0.0083, -0.0049, ..., -0.0028, 0.0301, -0.0205],\n", - " [ 0.0031, 0.0139, 0.0070, ..., 0.0120, 0.0004, -0.0226],\n", - " [ 0.0114, -0.0173, 0.0212, ..., -0.0413, -0.0069, 0.0007]])),\n", - " ('model.layers.8.mlp.up_proj.weight',\n", - " tensor([[-0.0005, 0.0028, -0.0137, ..., 0.0078, 0.0348, 0.0006],\n", - " [-0.0020, 0.0300, -0.0056, ..., -0.0258, -0.0130, -0.0212],\n", - " [-0.0135, -0.0111, 0.0151, ..., 0.0043, -0.0426, -0.0109],\n", - " ...,\n", - " [ 0.0273, 0.0057, -0.0108, ..., -0.0205, 0.0005, -0.0239],\n", - " [ 0.0226, 0.0325, -0.0187, ..., 0.0069, -0.0132, -0.0002],\n", - " [ 0.0280, -0.0007, -0.0047, ..., 0.0159, -0.0054, -0.0172]])),\n", - " ('model.layers.8.mlp.down_proj.weight',\n", - " tensor([[-0.0091, 0.0072, 0.0030, ..., 0.0025, -0.0159, -0.0277],\n", - " [ 0.0159, -0.0260, -0.0076, ..., -0.0059, -0.0129, 0.0358],\n", - " [ 0.0026, -0.0357, -0.0138, ..., -0.0326, -0.0291, 0.0010],\n", - " ...,\n", - " [-0.0237, 0.0272, -0.0130, ..., -0.0280, 0.0097, -0.0563],\n", - " [ 0.0092, 0.0056, 0.0079, ..., -0.0224, 0.0039, -0.0054],\n", - " [-0.0109, -0.0241, -0.0223, ..., -0.0187, 0.0190, 0.0082]])),\n", - " ('model.layers.8.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.8.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.9.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.9.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.9.mixer.in_proj.weight',\n", - " tensor([[ 4.9824e-02, 5.7576e-03, -5.1022e-03, ..., -2.5615e-02,\n", - " 7.1750e-04, 1.5247e-02],\n", - " [-2.8065e-02, -1.2649e-02, -2.3566e-02, ..., 1.7742e-02,\n", - " -1.1202e-02, -2.1476e-02],\n", - " [ 2.0911e-02, 1.6496e-02, -1.9818e-02, ..., 4.0223e-02,\n", - " 1.8544e-02, -2.3633e-02],\n", - " ...,\n", - " [-4.3387e-02, -1.6504e-02, 2.2008e-02, ..., -2.5138e-03,\n", - " -5.6073e-03, -4.8212e-03],\n", - " [-1.9964e-05, -1.5835e-02, 1.2977e-02, ..., 4.1913e-03,\n", - " 4.5898e-02, -3.5822e-02],\n", - " [ 3.1376e-02, -5.4614e-03, -2.5093e-02, ..., -3.7903e-03,\n", - " 1.3560e-02, 3.3366e-02]])),\n", - " ('model.layers.9.mixer.conv1d.weight',\n", - " tensor([[[ 0.1986, -0.1666, -0.4140, -0.4607]],\n", - " \n", - " [[-0.3454, -0.3973, 0.2169, -0.2138]],\n", - " \n", - " [[ 0.2006, -0.3736, 0.3944, -0.0589]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4604, 0.1224, -0.2571, -0.0286]],\n", - " \n", - " [[-0.2723, -0.1617, 0.3483, 0.2299]],\n", - " \n", - " [[ 0.4866, 0.2559, 0.3969, 0.0554]]])),\n", - " ('model.layers.9.mixer.conv1d.bias',\n", - " tensor([ 0.3388, 0.4633, -0.3762, ..., -0.3491, -0.2971, 0.0494])),\n", - " ('model.layers.9.mixer.out_proj.weight',\n", - " tensor([[ 0.0023, -0.0181, 0.0358, ..., 0.0243, 0.0070, -0.0183],\n", - " [ 0.0006, 0.0065, 0.0057, ..., -0.0351, -0.0107, 0.0132],\n", - " [ 0.0153, -0.0038, 0.0059, ..., -0.0285, -0.0247, -0.0104],\n", - " ...,\n", - " [ 0.0244, -0.0120, 0.0064, ..., -0.0133, 0.0263, 0.0016],\n", - " [ 0.0056, -0.0111, 0.0029, ..., -0.0017, -0.0172, -0.0071],\n", - " [-0.0056, -0.0192, -0.0238, ..., 0.0245, -0.0102, -0.0331]])),\n", - " ('model.layers.9.mlp.gate_proj.weight',\n", - " tensor([[-0.0132, 0.0014, -0.0413, ..., -0.0254, -0.0245, 0.0031],\n", - " [-0.0195, -0.0107, -0.0192, ..., 0.0012, -0.0026, 0.0148],\n", - " [-0.0074, -0.0070, -0.0078, ..., 0.0013, -0.0011, -0.0111],\n", - " ...,\n", - " [-0.0137, 0.0302, 0.0084, ..., -0.0063, -0.0065, 0.0240],\n", - " [ 0.0072, 0.0134, 0.0161, ..., 0.0122, 0.0182, 0.0137],\n", - " [ 0.0079, 0.0008, 0.0160, ..., 0.0281, 0.0226, 0.0058]])),\n", - " ('model.layers.9.mlp.up_proj.weight',\n", - " tensor([[ 0.0078, 0.0153, -0.0155, ..., 0.0153, -0.0164, -0.0140],\n", - " [-0.0072, -0.0050, 0.0030, ..., 0.0146, -0.0148, -0.0080],\n", - " [ 0.0165, -0.0078, 0.0005, ..., -0.0545, -0.0096, 0.0296],\n", - " ...,\n", - " [-0.0253, 0.0183, -0.0081, ..., -0.0061, 0.0270, -0.0003],\n", - " [-0.0015, -0.0320, 0.0361, ..., -0.0087, 0.0341, -0.0157],\n", - " [ 0.0041, 0.0102, -0.0195, ..., -0.0441, -0.0106, 0.0275]])),\n", - " ('model.layers.9.mlp.down_proj.weight',\n", - " tensor([[-6.3367e-02, -1.8214e-02, 5.7221e-03, ..., 2.1307e-02,\n", - " -3.0707e-02, -1.3281e-02],\n", - " [-7.7457e-05, -9.1894e-05, 6.8686e-03, ..., -4.7175e-03,\n", - " -1.1585e-03, -2.7604e-02],\n", - " [ 2.9301e-02, -5.9431e-03, -2.5356e-03, ..., -2.7858e-02,\n", - " 1.1647e-02, 1.1245e-02],\n", - " ...,\n", - " [-1.0442e-02, -9.6151e-03, -3.6635e-02, ..., -1.1052e-02,\n", - " -4.5122e-03, 4.0012e-03],\n", - " [ 3.2950e-02, -1.3836e-03, -7.8318e-03, ..., -1.2788e-03,\n", - " 2.3422e-02, -3.2098e-02],\n", - " [-9.2294e-03, 1.3838e-02, -2.0327e-02, ..., -3.8760e-02,\n", - " 2.2118e-02, 1.0696e-02]])),\n", - " ('model.layers.9.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.9.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.10.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.10.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.10.mixer.in_proj.weight',\n", - " tensor([[ 0.0096, -0.0159, 0.0141, ..., 0.0111, 0.0218, 0.0220],\n", - " [-0.0381, -0.0015, 0.0126, ..., -0.0066, -0.0034, -0.0119],\n", - " [ 0.0223, 0.0032, -0.0195, ..., -0.0107, -0.0018, 0.0059],\n", - " ...,\n", - " [-0.0256, -0.0170, -0.0362, ..., -0.0007, -0.0039, 0.0075],\n", - " [ 0.0136, -0.0045, 0.0128, ..., -0.0017, 0.0083, -0.0004],\n", - " [-0.0246, -0.0021, 0.0073, ..., 0.0020, 0.0071, 0.0090]])),\n", - " ('model.layers.10.mixer.conv1d.weight',\n", - " tensor([[[ 0.0463, -0.4497, -0.0679, -0.2209]],\n", - " \n", - " [[-0.3805, 0.4459, 0.1999, -0.4996]],\n", - " \n", - " [[ 0.1529, 0.1789, -0.1535, 0.1824]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.1087, -0.4478, -0.0420, 0.3437]],\n", - " \n", - " [[-0.2809, -0.4617, 0.3209, 0.4873]],\n", - " \n", - " [[ 0.1139, -0.0060, -0.0219, 0.0853]]])),\n", - " ('model.layers.10.mixer.conv1d.bias',\n", - " tensor([ 0.1364, -0.0475, 0.0849, ..., 0.1928, 0.2075, 0.1058])),\n", - " ('model.layers.10.mixer.out_proj.weight',\n", - " tensor([[-0.0164, -0.0188, 0.0174, ..., -0.0106, -0.0107, -0.0036],\n", - " [ 0.0048, -0.0016, -0.0444, ..., -0.0182, -0.0264, -0.0038],\n", - " [ 0.0089, -0.0225, -0.0002, ..., -0.0141, -0.0008, -0.0037],\n", - " ...,\n", - " [-0.0005, 0.0159, 0.0033, ..., 0.0187, -0.0064, 0.0233],\n", - " [-0.0050, 0.0296, 0.0147, ..., -0.0018, 0.0137, -0.0346],\n", - " [-0.0064, -0.0132, -0.0434, ..., -0.0173, -0.0113, -0.0175]])),\n", - " ('model.layers.10.mlp.gate_proj.weight',\n", - " tensor([[-0.0174, -0.0053, -0.0325, ..., -0.0072, -0.0280, 0.0033],\n", - " [ 0.0006, -0.0160, 0.0346, ..., 0.0019, 0.0059, 0.0198],\n", - " [ 0.0231, -0.0187, 0.0115, ..., 0.0085, 0.0080, 0.0061],\n", - " ...,\n", - " [ 0.0153, 0.0241, -0.0184, ..., 0.0089, -0.0242, 0.0010],\n", - " [-0.0019, -0.0322, 0.0011, ..., -0.0097, -0.0305, 0.0065],\n", - " [-0.0107, 0.0240, 0.0168, ..., 0.0226, -0.0238, 0.0117]])),\n", - " ('model.layers.10.mlp.up_proj.weight',\n", - " tensor([[-0.0072, 0.0352, 0.0282, ..., -0.0025, -0.0114, 0.0129],\n", - " [-0.0102, 0.0196, 0.0760, ..., 0.0461, -0.0058, -0.0112],\n", - " [-0.0271, 0.0323, -0.0069, ..., 0.0133, -0.0371, -0.0619],\n", - " ...,\n", - " [ 0.0100, 0.0011, 0.0262, ..., -0.0232, 0.0217, 0.0002],\n", - " [ 0.0151, -0.0266, -0.0074, ..., 0.0096, 0.0036, 0.0033],\n", - " [ 0.0004, 0.0103, 0.0363, ..., -0.0095, -0.0309, -0.0059]])),\n", - " ('model.layers.10.mlp.down_proj.weight',\n", - " tensor([[ 0.0124, -0.0225, -0.0294, ..., 0.0280, 0.0056, 0.0231],\n", - " [ 0.0124, -0.0030, 0.0014, ..., 0.0323, 0.0094, -0.0034],\n", - " [-0.0078, 0.0041, -0.0056, ..., 0.0241, -0.0278, -0.0152],\n", - " ...,\n", - " [-0.0044, 0.0025, -0.0161, ..., -0.0075, -0.0126, 0.0014],\n", - " [-0.0109, -0.0050, 0.0327, ..., -0.0300, -0.0048, 0.0284],\n", - " [ 0.0050, -0.0183, 0.0086, ..., -0.0072, 0.0139, -0.0010]])),\n", - " ('model.layers.10.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.10.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.11.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.11.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.11.mixer.in_proj.weight',\n", - " tensor([[-0.0133, 0.0225, 0.0486, ..., -0.0214, -0.0120, -0.0150],\n", - " [ 0.0183, 0.0020, 0.0079, ..., -0.0163, 0.0016, -0.0214],\n", - " [-0.0276, -0.0112, 0.0121, ..., -0.0057, -0.0143, -0.0462],\n", - " ...,\n", - " [-0.0142, -0.0080, -0.0194, ..., 0.0087, -0.0212, -0.0140],\n", - " [ 0.0060, -0.0005, -0.0171, ..., -0.0017, 0.0223, 0.0169],\n", - " [-0.0290, -0.0016, 0.0117, ..., 0.0037, 0.0047, 0.0152]])),\n", - " ('model.layers.11.mixer.conv1d.weight',\n", - " tensor([[[-0.2822, -0.4216, 0.4786, 0.0802]],\n", - " \n", - " [[-0.3671, 0.1761, -0.2686, 0.1631]],\n", - " \n", - " [[-0.3902, -0.2811, -0.0748, 0.4662]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.1623, 0.2871, -0.4585, 0.4755]],\n", - " \n", - " [[-0.0260, 0.4541, -0.2983, 0.2297]],\n", - " \n", - " [[-0.2991, -0.3590, -0.3256, -0.1434]]])),\n", - " ('model.layers.11.mixer.conv1d.bias',\n", - " tensor([ 0.1218, -0.0542, 0.3485, ..., 0.0528, 0.2711, -0.2811])),\n", - " ('model.layers.11.mixer.out_proj.weight',\n", - " tensor([[ 0.0032, 0.0028, -0.0122, ..., -0.0299, -0.0105, 0.0021],\n", - " [-0.0466, -0.0170, -0.0017, ..., 0.0156, -0.0287, 0.0066],\n", - " [ 0.0016, 0.0054, -0.0071, ..., -0.0240, 0.0215, -0.0046],\n", - " ...,\n", - " [-0.0210, 0.0034, -0.0267, ..., 0.0461, -0.0076, -0.0016],\n", - " [-0.0012, -0.0101, 0.0196, ..., 0.0121, -0.0043, -0.0143],\n", - " [-0.0067, 0.0086, 0.0134, ..., 0.0080, 0.0255, 0.0225]])),\n", - " ('model.layers.11.mlp.gate_proj.weight',\n", - " tensor([[ 0.0179, -0.0429, -0.0134, ..., 0.0110, 0.0368, -0.0259],\n", - " [ 0.0013, -0.0231, 0.0072, ..., -0.0056, -0.0012, -0.0037],\n", - " [-0.0172, -0.0162, 0.0088, ..., -0.0175, 0.0079, -0.0065],\n", - " ...,\n", - " [ 0.0287, -0.0289, 0.0045, ..., 0.0039, 0.0269, 0.0199],\n", - " [ 0.0043, -0.0202, -0.0261, ..., 0.0104, -0.0161, -0.0057],\n", - " [-0.0154, 0.0085, 0.0061, ..., 0.0208, 0.0001, 0.0166]])),\n", - " ('model.layers.11.mlp.up_proj.weight',\n", - " tensor([[-0.0107, 0.0328, 0.0065, ..., -0.0190, -0.0082, -0.0047],\n", - " [-0.0001, 0.0102, 0.0310, ..., -0.0396, -0.0278, -0.0095],\n", - " [-0.0288, 0.0052, 0.0137, ..., -0.0220, 0.0007, -0.0170],\n", - " ...,\n", - " [ 0.0213, -0.0074, -0.0033, ..., 0.0183, 0.0336, -0.0180],\n", - " [-0.0098, -0.0162, 0.0486, ..., 0.0191, 0.0064, 0.0269],\n", - " [-0.0251, 0.0081, 0.0053, ..., 0.0110, 0.0023, 0.0041]])),\n", - " ('model.layers.11.mlp.down_proj.weight',\n", - " tensor([[ 0.0166, -0.0410, 0.0066, ..., -0.0273, 0.0220, 0.0184],\n", - " [ 0.0092, 0.0087, -0.0136, ..., 0.0013, -0.0205, 0.0247],\n", - " [-0.0252, -0.0040, -0.0112, ..., -0.0331, 0.0201, -0.0038],\n", - " ...,\n", - " [ 0.0072, 0.0190, 0.0089, ..., 0.0098, -0.0235, -0.0141],\n", - " [-0.0045, -0.0381, -0.0134, ..., 0.0171, -0.0077, -0.0180],\n", - " [ 0.0109, 0.0060, 0.0048, ..., -0.0108, -0.0122, 0.0110]])),\n", - " ('model.layers.11.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.11.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.12.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.12.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.12.mixer.in_proj.weight',\n", - " tensor([[ 0.0043, 0.0138, 0.0138, ..., -0.0042, 0.0121, -0.0190],\n", - " [ 0.0002, -0.0199, 0.0315, ..., 0.0170, 0.0051, -0.0062],\n", - " [-0.0053, 0.0043, 0.0283, ..., -0.0087, 0.0069, -0.0160],\n", - " ...,\n", - " [-0.0313, 0.0200, 0.0036, ..., 0.0147, 0.0153, 0.0098],\n", - " [-0.0157, 0.0120, -0.0112, ..., 0.0166, -0.0005, 0.0066],\n", - " [-0.0271, 0.0037, 0.0163, ..., 0.0304, 0.0023, 0.0083]])),\n", - " ('model.layers.12.mixer.conv1d.weight',\n", - " tensor([[[-0.4295, -0.2474, -0.2324, -0.2138]],\n", - " \n", - " [[ 0.3607, -0.4824, 0.1667, 0.1348]],\n", - " \n", - " [[ 0.3596, 0.1167, 0.1089, -0.4010]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3527, -0.3346, -0.3755, 0.1450]],\n", - " \n", - " [[-0.1921, -0.0632, -0.4885, -0.3986]],\n", - " \n", - " [[ 0.1950, 0.3037, -0.1630, 0.0353]]])),\n", - " ('model.layers.12.mixer.conv1d.bias',\n", - " tensor([0.3103, 0.0451, 0.4533, ..., 0.0235, 0.1819, 0.3933])),\n", - " ('model.layers.12.mixer.out_proj.weight',\n", - " tensor([[ 0.0167, -0.0197, -0.0054, ..., 0.0096, 0.0271, -0.0118],\n", - " [ 0.0167, -0.0455, 0.0001, ..., 0.0003, 0.0265, 0.0111],\n", - " [ 0.0231, -0.0113, 0.0195, ..., -0.0171, -0.0044, -0.0244],\n", - " ...,\n", - " [ 0.0042, 0.0048, 0.0357, ..., 0.0126, -0.0288, 0.0149],\n", - " [ 0.0192, 0.0078, 0.0126, ..., 0.0029, 0.0255, -0.0203],\n", - " [-0.0054, -0.0543, 0.0039, ..., -0.0240, 0.0282, 0.0082]])),\n", - " ('model.layers.12.mlp.gate_proj.weight',\n", - " tensor([[-0.0417, -0.0193, -0.0022, ..., 0.0031, 0.0337, 0.0175],\n", - " [ 0.0215, -0.0109, -0.0657, ..., -0.0145, -0.0475, -0.0091],\n", - " [-0.0225, -0.0012, -0.0020, ..., -0.0291, 0.0097, 0.0163],\n", - " ...,\n", - " [-0.0018, 0.0048, -0.0265, ..., -0.0056, 0.0446, 0.0045],\n", - " [ 0.0270, 0.0086, -0.0110, ..., -0.0038, 0.0176, 0.0138],\n", - " [-0.0134, 0.0046, -0.0186, ..., -0.0098, 0.0191, 0.0095]])),\n", - " ('model.layers.12.mlp.up_proj.weight',\n", - " tensor([[ 0.0180, 0.0075, 0.0147, ..., 0.0142, 0.0291, -0.0303],\n", - " [-0.0079, -0.0277, -0.0151, ..., -0.0069, -0.0045, -0.0223],\n", - " [ 0.0180, -0.0087, 0.0074, ..., 0.0215, 0.0274, -0.0199],\n", - " ...,\n", - " [-0.0215, -0.0115, 0.0140, ..., -0.0283, -0.0171, -0.0229],\n", - " [ 0.0231, -0.0179, -0.0386, ..., 0.0364, 0.0311, 0.0048],\n", - " [-0.0111, 0.0079, 0.0328, ..., 0.0285, 0.0423, 0.0039]])),\n", - " ('model.layers.12.mlp.down_proj.weight',\n", - " tensor([[-0.0361, 0.0192, -0.0005, ..., -0.0151, 0.0116, -0.0068],\n", - " [ 0.0203, -0.0064, 0.0061, ..., 0.0325, -0.0004, -0.0299],\n", - " [-0.0028, 0.0131, 0.0141, ..., -0.0108, -0.0070, -0.0090],\n", - " ...,\n", - " [ 0.0165, -0.0198, -0.0242, ..., 0.0162, 0.0099, 0.0025],\n", - " [ 0.0148, 0.0056, -0.0139, ..., 0.0108, -0.0477, 0.0225],\n", - " [ 0.0156, 0.0249, -0.0287, ..., -0.0200, -0.0496, 0.0169]])),\n", - " ('model.layers.12.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.12.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.13.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.13.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.13.mixer.in_proj.weight',\n", - " tensor([[-0.0064, -0.0200, 0.0384, ..., -0.0036, 0.0158, -0.0007],\n", - " [-0.0074, 0.0105, 0.0043, ..., 0.0097, 0.0259, -0.0012],\n", - " [ 0.0297, -0.0146, -0.0012, ..., 0.0273, 0.0309, 0.0087],\n", - " ...,\n", - " [ 0.0204, -0.0063, 0.0136, ..., -0.0092, 0.0196, 0.0057],\n", - " [ 0.0195, 0.0059, 0.0228, ..., 0.0093, -0.0183, -0.0003],\n", - " [-0.0131, -0.0447, -0.0262, ..., -0.0125, 0.0237, -0.0404]])),\n", - " ('model.layers.13.mixer.conv1d.weight',\n", - " tensor([[[ 7.7458e-03, 4.9829e-01, 2.1690e-01, -2.3587e-01]],\n", - " \n", - " [[ 3.7281e-01, -4.0991e-03, 2.4588e-01, -1.1600e-01]],\n", - " \n", - " [[-4.8238e-01, -2.8961e-01, -4.4331e-02, 1.0011e-01]],\n", - " \n", - " ...,\n", - " \n", - " [[-3.6304e-01, -1.4106e-01, -3.5434e-01, 1.4923e-01]],\n", - " \n", - " [[-2.3703e-01, 3.9285e-04, -2.1456e-02, -2.5568e-01]],\n", - " \n", - " [[ 1.5303e-02, -8.3474e-03, -3.2668e-01, -4.8096e-01]]])),\n", - " ('model.layers.13.mixer.conv1d.bias',\n", - " tensor([-0.2462, 0.1532, -0.2298, ..., -0.3016, 0.1210, -0.3777])),\n", - " ('model.layers.13.mixer.out_proj.weight',\n", - " tensor([[-0.0019, 0.0103, 0.0098, ..., -0.0050, 0.0180, -0.0117],\n", - " [-0.0153, 0.0134, -0.0102, ..., 0.0327, -0.0387, 0.0025],\n", - " [ 0.0102, -0.0038, 0.0224, ..., -0.0118, 0.0234, 0.0014],\n", - " ...,\n", - " [-0.0201, 0.0233, 0.0189, ..., 0.0010, 0.0313, 0.0130],\n", - " [ 0.0193, 0.0035, -0.0253, ..., 0.0084, -0.0208, 0.0372],\n", - " [ 0.0367, -0.0029, -0.0205, ..., -0.0055, -0.0209, 0.0082]])),\n", - " ('model.layers.13.mlp.gate_proj.weight',\n", - " tensor([[ 0.0148, -0.0052, 0.0371, ..., -0.0118, 0.0397, -0.0234],\n", - " [ 0.0237, -0.0323, 0.0219, ..., 0.0098, -0.0304, 0.0165],\n", - " [ 0.0168, -0.0289, 0.0038, ..., 0.0022, 0.0174, 0.0043],\n", - " ...,\n", - " [-0.0135, 0.0258, -0.0172, ..., 0.0251, -0.0071, -0.0384],\n", - " [ 0.0005, -0.0123, 0.0116, ..., 0.0041, -0.0108, -0.0068],\n", - " [ 0.0116, 0.0069, 0.0063, ..., 0.0045, -0.0145, 0.0185]])),\n", - " ('model.layers.13.mlp.up_proj.weight',\n", - " tensor([[-0.0002, -0.0120, 0.0069, ..., 0.0005, -0.0108, -0.0284],\n", - " [ 0.0215, 0.0045, 0.0167, ..., 0.0177, -0.0030, 0.0051],\n", - " [ 0.0265, 0.0169, 0.0047, ..., 0.0069, -0.0299, 0.0196],\n", - " ...,\n", - " [ 0.0127, -0.0063, 0.0242, ..., -0.0061, -0.0263, 0.0041],\n", - " [ 0.0142, -0.0515, -0.0221, ..., -0.0369, -0.0399, -0.0210],\n", - " [ 0.0123, 0.0133, -0.0269, ..., 0.0092, -0.0177, 0.0226]])),\n", - " ('model.layers.13.mlp.down_proj.weight',\n", - " tensor([[ 0.0048, 0.0360, -0.0037, ..., 0.0169, 0.0304, -0.0162],\n", - " [ 0.0271, -0.0121, 0.0108, ..., -0.0424, 0.0293, -0.0137],\n", - " [ 0.0225, -0.0061, -0.0096, ..., 0.0075, -0.0168, 0.0142],\n", - " ...,\n", - " [ 0.0039, -0.0152, -0.0156, ..., 0.0181, 0.0105, 0.0070],\n", - " [ 0.0311, 0.0205, 0.0259, ..., -0.0025, 0.0060, -0.0125],\n", - " [ 0.0004, -0.0114, 0.0022, ..., -0.0159, -0.0290, 0.0036]])),\n", - " ('model.layers.13.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.13.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.14.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.14.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.14.mixer.in_proj.weight',\n", - " tensor([[-0.0123, 0.0054, 0.0059, ..., 0.0285, -0.0292, -0.0184],\n", - " [-0.0146, -0.0175, 0.0155, ..., -0.0206, -0.0190, -0.0172],\n", - " [ 0.0050, -0.0235, -0.0159, ..., -0.0013, -0.0102, 0.0082],\n", - " ...,\n", - " [-0.0243, -0.0013, 0.0312, ..., -0.0141, -0.0156, 0.0279],\n", - " [ 0.0018, 0.0181, -0.0188, ..., 0.0593, -0.0155, 0.0156],\n", - " [ 0.0036, 0.0182, -0.0308, ..., 0.0306, -0.0035, 0.0037]])),\n", - " ('model.layers.14.mixer.conv1d.weight',\n", - " tensor([[[-0.4608, 0.4926, -0.2625, 0.3060]],\n", - " \n", - " [[-0.0932, 0.0153, 0.2298, -0.1735]],\n", - " \n", - " [[-0.1927, 0.1979, -0.1773, 0.3277]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.0538, -0.2180, -0.4857, -0.1428]],\n", - " \n", - " [[-0.1736, 0.2405, 0.3148, -0.4481]],\n", - " \n", - " [[-0.4971, -0.1558, 0.2762, -0.1849]]])),\n", - " ('model.layers.14.mixer.conv1d.bias',\n", - " tensor([-0.2181, -0.2375, 0.0896, ..., 0.0744, 0.0857, 0.4347])),\n", - " ('model.layers.14.mixer.out_proj.weight',\n", - " tensor([[-3.8364e-04, 2.4458e-02, 5.8783e-03, ..., -1.3479e-02,\n", - " -2.4306e-02, 5.7698e-03],\n", - " [ 4.5843e-02, -3.9217e-03, -6.9897e-03, ..., 5.5401e-03,\n", - " -1.4523e-02, 1.2266e-02],\n", - " [-7.1069e-03, 5.5550e-03, 1.1359e-02, ..., 3.5839e-02,\n", - " 1.0787e-02, 8.4053e-03],\n", - " ...,\n", - " [ 3.3029e-03, 5.4333e-03, -9.3382e-03, ..., -1.7376e-02,\n", - " 1.5601e-02, -6.3227e-03],\n", - " [-6.9199e-03, -1.6950e-02, 1.5155e-03, ..., 1.2324e-02,\n", - " 1.2259e-02, 5.5500e-02],\n", - " [-1.6177e-02, -6.5257e-05, -9.3656e-03, ..., 1.0653e-02,\n", - " 1.8864e-02, -1.2508e-02]])),\n", - " ('model.layers.14.mlp.gate_proj.weight',\n", - " tensor([[ 0.0279, 0.0025, 0.0214, ..., -0.0137, -0.0042, 0.0172],\n", - " [-0.0240, -0.0150, 0.0170, ..., 0.0090, 0.0002, 0.0172],\n", - " [-0.0181, 0.0052, -0.0418, ..., 0.0106, 0.0052, -0.0264],\n", - " ...,\n", - " [-0.0295, 0.0323, 0.0387, ..., -0.0116, -0.0140, -0.0053],\n", - " [ 0.0411, 0.0189, 0.0236, ..., 0.0094, -0.0176, -0.0066],\n", - " [ 0.0004, 0.0291, 0.0402, ..., 0.0127, -0.0009, 0.0010]])),\n", - " ('model.layers.14.mlp.up_proj.weight',\n", - " tensor([[ 0.0198, -0.0115, -0.0045, ..., 0.0273, 0.0012, -0.0082],\n", - " [-0.0217, 0.0075, 0.0006, ..., 0.0047, -0.0416, -0.0011],\n", - " [ 0.0012, -0.0214, -0.0211, ..., 0.0030, -0.0176, -0.0215],\n", - " ...,\n", - " [ 0.0062, -0.0305, 0.0310, ..., 0.0044, -0.0379, 0.0155],\n", - " [-0.0062, 0.0451, 0.0167, ..., 0.0062, -0.0033, 0.0012],\n", - " [ 0.0293, -0.0186, 0.0295, ..., 0.0092, 0.0100, 0.0038]])),\n", - " ('model.layers.14.mlp.down_proj.weight',\n", - " tensor([[ 0.0019, 0.0114, -0.0202, ..., 0.0227, -0.0227, -0.0005],\n", - " [-0.0437, -0.0045, -0.0385, ..., -0.0083, -0.0135, 0.0172],\n", - " [-0.0032, -0.0024, 0.0137, ..., 0.0071, 0.0034, 0.0104],\n", - " ...,\n", - " [ 0.0210, -0.0237, -0.0166, ..., -0.0105, 0.0490, 0.0155],\n", - " [-0.0109, 0.0112, 0.0082, ..., -0.0342, -0.0133, -0.0086],\n", - " [ 0.0282, -0.0210, -0.0127, ..., -0.0047, -0.0126, 0.0103]])),\n", - " ('model.layers.14.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.14.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.15.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.15.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.15.mixer.in_proj.weight',\n", - " tensor([[-0.0098, -0.0201, -0.0033, ..., -0.0289, 0.0275, 0.0186],\n", - " [ 0.0048, 0.0075, -0.0033, ..., 0.0011, 0.0042, 0.0040],\n", - " [-0.0079, -0.0025, 0.0018, ..., -0.0051, -0.0231, -0.0022],\n", - " ...,\n", - " [ 0.0186, -0.0104, -0.0062, ..., 0.0086, -0.0007, -0.0653],\n", - " [-0.0212, 0.0034, 0.0019, ..., 0.0167, 0.0050, 0.0120],\n", - " [ 0.0066, 0.0381, -0.0225, ..., -0.0043, 0.0229, -0.0004]])),\n", - " ('model.layers.15.mixer.conv1d.weight',\n", - " tensor([[[ 0.2306, 0.2721, 0.3406, 0.4513]],\n", - " \n", - " [[ 0.0991, 0.4973, 0.0010, -0.1445]],\n", - " \n", - " [[ 0.2975, 0.4813, 0.2817, -0.0468]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0104, -0.1473, 0.1685, -0.4390]],\n", - " \n", - " [[ 0.3669, 0.3461, 0.0845, 0.3576]],\n", - " \n", - " [[-0.1177, 0.0524, 0.4329, 0.0687]]])),\n", - " ('model.layers.15.mixer.conv1d.bias',\n", - " tensor([-0.0356, 0.4173, 0.3287, ..., -0.0141, 0.1365, 0.2086])),\n", - " ('model.layers.15.mixer.out_proj.weight',\n", - " tensor([[-0.0137, -0.0239, -0.0133, ..., -0.0177, -0.0125, -0.0015],\n", - " [ 0.0168, 0.0120, 0.0034, ..., 0.0098, 0.0098, 0.0110],\n", - " [-0.0315, 0.0447, 0.0189, ..., 0.0305, 0.0131, -0.0230],\n", - " ...,\n", - " [-0.0480, 0.0170, 0.0025, ..., 0.0317, -0.0378, -0.0236],\n", - " [-0.0319, -0.0290, 0.0023, ..., -0.0093, 0.0354, 0.0126],\n", - " [-0.0107, 0.0100, -0.0101, ..., 0.0046, 0.0205, -0.0203]])),\n", - " ('model.layers.15.mlp.gate_proj.weight',\n", - " tensor([[ 0.0160, 0.0432, 0.0073, ..., -0.0003, -0.0170, 0.0236],\n", - " [ 0.0055, 0.0066, -0.0311, ..., 0.0049, -0.0130, 0.0040],\n", - " [-0.0147, -0.0184, 0.0281, ..., 0.0016, 0.0077, -0.0072],\n", - " ...,\n", - " [-0.0049, -0.0434, -0.0118, ..., 0.0137, -0.0225, -0.0058],\n", - " [ 0.0221, -0.0077, 0.0029, ..., 0.0087, -0.0361, -0.0100],\n", - " [ 0.0263, 0.0228, 0.0050, ..., -0.0557, 0.0037, 0.0196]])),\n", - " ('model.layers.15.mlp.up_proj.weight',\n", - " tensor([[ 0.0093, -0.0189, 0.0173, ..., 0.0276, 0.0075, -0.0215],\n", - " [-0.0147, 0.0241, 0.0109, ..., 0.0120, 0.0032, 0.0327],\n", - " [ 0.0036, 0.0127, 0.0116, ..., 0.0100, -0.0003, 0.0233],\n", - " ...,\n", - " [-0.0063, 0.0160, 0.0138, ..., -0.0078, -0.0098, 0.0150],\n", - " [ 0.0138, -0.0236, 0.0109, ..., -0.0156, -0.0143, 0.0273],\n", - " [ 0.0345, 0.0201, -0.0119, ..., -0.0182, 0.0053, 0.0105]])),\n", - " ('model.layers.15.mlp.down_proj.weight',\n", - " tensor([[-0.0114, 0.0138, -0.0110, ..., 0.0084, -0.0144, 0.0100],\n", - " [ 0.0016, -0.0069, 0.0172, ..., -0.0394, 0.0368, 0.0468],\n", - " [-0.0184, -0.0094, -0.0273, ..., -0.0195, 0.0148, 0.0142],\n", - " ...,\n", - " [ 0.0311, 0.0093, -0.0130, ..., -0.0023, 0.0395, -0.0375],\n", - " [ 0.0056, 0.0027, 0.0061, ..., 0.0058, 0.0225, -0.0153],\n", - " [-0.0031, -0.0107, 0.0020, ..., -0.0173, -0.0050, 0.0423]])),\n", - " ('model.layers.15.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.15.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.16.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.16.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.16.mixer.in_proj.weight',\n", - " tensor([[-0.0063, 0.0006, 0.0130, ..., 0.0186, 0.0408, 0.0126],\n", - " [-0.0015, -0.0029, 0.0268, ..., -0.0042, -0.0209, -0.0046],\n", - " [-0.0034, -0.0286, 0.0185, ..., -0.0125, 0.0050, 0.0033],\n", - " ...,\n", - " [ 0.0045, 0.0133, 0.0220, ..., 0.0165, 0.0287, 0.0371],\n", - " [ 0.0100, -0.0232, 0.0103, ..., -0.0083, -0.0105, -0.0187],\n", - " [-0.0412, -0.0035, 0.0028, ..., 0.0286, 0.0349, -0.0037]])),\n", - " ('model.layers.16.mixer.conv1d.weight',\n", - " tensor([[[-0.1874, 0.2517, 0.0537, 0.1258]],\n", - " \n", - " [[ 0.1465, 0.2013, 0.3547, 0.2689]],\n", - " \n", - " [[ 0.4834, 0.4906, 0.0844, -0.0541]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3004, 0.3313, 0.1688, 0.4381]],\n", - " \n", - " [[-0.0606, 0.3455, -0.0910, 0.1148]],\n", - " \n", - " [[-0.1421, -0.1254, -0.2353, -0.1675]]])),\n", - " ('model.layers.16.mixer.conv1d.bias',\n", - " tensor([ 0.2835, 0.2361, 0.1225, ..., -0.2119, -0.1929, 0.3877])),\n", - " ('model.layers.16.mixer.out_proj.weight',\n", - " tensor([[-0.0121, 0.0194, 0.0060, ..., -0.0029, -0.0147, -0.0085],\n", - " [-0.0216, -0.0012, 0.0287, ..., 0.0102, -0.0133, -0.0153],\n", - " [ 0.0136, -0.0296, 0.0417, ..., -0.0118, -0.0283, 0.0359],\n", - " ...,\n", - " [-0.0263, -0.0003, 0.0022, ..., 0.0135, -0.0519, -0.0254],\n", - " [ 0.0121, -0.0144, -0.0026, ..., 0.0096, 0.0130, 0.0095],\n", - " [-0.0147, -0.0217, 0.0099, ..., 0.0267, -0.0072, -0.0213]])),\n", - " ('model.layers.16.mlp.gate_proj.weight',\n", - " tensor([[ 0.0103, -0.0396, -0.0127, ..., 0.0020, -0.0055, 0.0291],\n", - " [ 0.0194, 0.0357, -0.0020, ..., -0.0112, 0.0448, -0.0224],\n", - " [-0.0390, 0.0142, -0.0224, ..., -0.0030, 0.0102, 0.0078],\n", - " ...,\n", - " [ 0.0165, -0.0251, 0.0196, ..., 0.0213, 0.0040, -0.0228],\n", - " [-0.0145, 0.0218, -0.0032, ..., -0.0240, -0.0079, 0.0256],\n", - " [ 0.0539, -0.0027, -0.0227, ..., -0.0184, -0.0109, 0.0236]])),\n", - " ('model.layers.16.mlp.up_proj.weight',\n", - " tensor([[ 7.1125e-03, -3.2583e-04, -2.6297e-02, ..., -4.9575e-03,\n", - " -1.2243e-02, -1.3005e-02],\n", - " [ 2.5637e-02, -1.1874e-02, 1.1376e-02, ..., -1.4700e-02,\n", - " -1.5193e-02, 2.6111e-03],\n", - " [-4.8919e-02, -4.9716e-04, 5.8527e-03, ..., 8.6775e-05,\n", - " 1.0694e-02, 3.7682e-03],\n", - " ...,\n", - " [ 8.8393e-03, -4.3317e-02, 2.8372e-02, ..., 2.2709e-02,\n", - " -4.8128e-03, 1.6899e-02],\n", - " [ 1.3257e-02, 2.1000e-02, 1.5035e-03, ..., 1.5603e-02,\n", - " -5.5857e-03, 4.0449e-03],\n", - " [-2.6754e-02, -1.6263e-02, 1.9013e-02, ..., -9.0918e-03,\n", - " -8.0242e-03, -1.0925e-02]])),\n", - " ('model.layers.16.mlp.down_proj.weight',\n", - " tensor([[ 0.0207, -0.0038, -0.0234, ..., 0.0299, -0.0329, -0.0117],\n", - " [-0.0316, 0.0032, 0.0131, ..., 0.0020, -0.0320, 0.0381],\n", - " [-0.0192, -0.0031, -0.0030, ..., -0.0224, 0.0037, 0.0085],\n", - " ...,\n", - " [ 0.0044, 0.0281, -0.0208, ..., 0.0179, -0.0085, -0.0010],\n", - " [-0.0076, -0.0008, 0.0483, ..., 0.0082, -0.0177, -0.0039],\n", - " [ 0.0224, 0.0019, 0.0181, ..., 0.0143, -0.0252, 0.0022]])),\n", - " ('model.layers.16.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.16.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.17.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.17.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.17.mixer.in_proj.weight',\n", - " tensor([[-0.0115, 0.0061, -0.0062, ..., -0.0132, -0.0047, 0.0274],\n", - " [ 0.0076, 0.0278, -0.0147, ..., 0.0439, -0.0093, -0.0154],\n", - " [-0.0383, -0.0264, -0.0053, ..., -0.0206, 0.0275, 0.0188],\n", - " ...,\n", - " [ 0.0096, 0.0228, 0.0351, ..., 0.0227, 0.0138, -0.0164],\n", - " [ 0.0321, -0.0293, -0.0054, ..., 0.0109, -0.0113, -0.0130],\n", - " [-0.0120, -0.0132, 0.0092, ..., -0.0338, 0.0308, -0.0135]])),\n", - " ('model.layers.17.mixer.conv1d.weight',\n", - " tensor([[[-0.4933, 0.4156, 0.2523, -0.0026]],\n", - " \n", - " [[-0.2572, 0.4916, 0.3642, -0.2145]],\n", - " \n", - " [[ 0.0261, 0.4852, -0.1448, 0.2288]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3698, -0.4122, -0.2264, -0.1378]],\n", - " \n", - " [[ 0.1447, 0.4556, -0.0466, 0.0389]],\n", - " \n", - " [[-0.3891, 0.4149, 0.1454, -0.4282]]])),\n", - " ('model.layers.17.mixer.conv1d.bias',\n", - " tensor([-0.3919, -0.4015, 0.2591, ..., -0.3368, 0.2285, 0.1701])),\n", - " ('model.layers.17.mixer.out_proj.weight',\n", - " tensor([[-0.0127, -0.0155, 0.0193, ..., 0.0204, 0.0025, 0.0159],\n", - " [ 0.0192, 0.0194, -0.0169, ..., -0.0062, 0.0262, 0.0070],\n", - " [ 0.0397, 0.0009, 0.0189, ..., -0.0082, 0.0352, -0.0150],\n", - " ...,\n", - " [-0.0339, -0.0142, -0.0151, ..., 0.0229, 0.0032, 0.0038],\n", - " [ 0.0235, 0.0319, -0.0137, ..., -0.0121, 0.0112, 0.0162],\n", - " [ 0.0060, 0.0102, -0.0016, ..., 0.0118, 0.0158, -0.0140]])),\n", - " ('model.layers.17.mlp.gate_proj.weight',\n", - " tensor([[ 0.0285, -0.0090, -0.0095, ..., 0.0315, -0.0065, 0.0189],\n", - " [ 0.0040, -0.0358, -0.0039, ..., -0.0074, -0.0285, -0.0223],\n", - " [ 0.0202, 0.0021, -0.0104, ..., -0.0083, 0.0300, -0.0267],\n", - " ...,\n", - " [ 0.0093, -0.0008, -0.0372, ..., 0.0422, 0.0309, 0.0095],\n", - " [ 0.0027, 0.0252, 0.0378, ..., -0.0238, 0.0234, -0.0062],\n", - " [-0.0061, -0.0022, -0.0033, ..., 0.0157, -0.0296, 0.0034]])),\n", - " ('model.layers.17.mlp.up_proj.weight',\n", - " tensor([[ 0.0061, -0.0135, 0.0029, ..., 0.0328, 0.0008, -0.0072],\n", - " [ 0.0145, -0.0226, -0.0095, ..., 0.0114, 0.0224, -0.0160],\n", - " [ 0.0097, -0.0024, -0.0179, ..., 0.0073, -0.0061, -0.0195],\n", - " ...,\n", - " [ 0.0308, -0.0014, 0.0104, ..., 0.0047, 0.0026, 0.0243],\n", - " [-0.0364, 0.0350, 0.0031, ..., -0.0072, 0.0267, 0.0017],\n", - " [ 0.0227, -0.0146, 0.0146, ..., -0.0434, -0.0159, 0.0230]])),\n", - " ('model.layers.17.mlp.down_proj.weight',\n", - " tensor([[-0.0216, 0.0211, 0.0136, ..., -0.0004, 0.0051, 0.0415],\n", - " [-0.0061, -0.0123, 0.0156, ..., -0.0005, -0.0183, -0.0137],\n", - " [-0.0146, -0.0274, -0.0439, ..., -0.0033, -0.0030, -0.0074],\n", - " ...,\n", - " [-0.0108, -0.0005, -0.0094, ..., -0.0243, 0.0065, -0.0005],\n", - " [-0.0126, 0.0124, -0.0006, ..., -0.0282, -0.0110, 0.0128],\n", - " [-0.0162, -0.0102, 0.0025, ..., -0.0084, 0.0066, -0.0074]])),\n", - " ('model.layers.17.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.17.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.18.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.18.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.18.mixer.in_proj.weight',\n", - " tensor([[-9.4961e-03, -1.2349e-04, -7.1455e-03, ..., 1.9508e-02,\n", - " -6.8715e-03, -1.3565e-02],\n", - " [-2.9701e-03, 3.1580e-03, 1.8849e-02, ..., 7.6566e-03,\n", - " -1.0968e-02, -8.0445e-03],\n", - " [-1.5402e-02, -6.7267e-03, 9.6119e-03, ..., 1.9799e-02,\n", - " 2.0198e-03, -1.7366e-03],\n", - " ...,\n", - " [ 8.2379e-03, 5.1668e-03, 3.8116e-02, ..., -3.8710e-03,\n", - " 1.4452e-02, -2.5152e-02],\n", - " [ 1.1949e-02, -1.2245e-03, 1.0568e-02, ..., -3.1690e-02,\n", - " 3.8135e-05, 1.7263e-02],\n", - " [ 1.6173e-04, 5.6721e-04, 2.1043e-02, ..., -3.6167e-02,\n", - " -1.1129e-02, -9.6768e-03]])),\n", - " ('model.layers.18.mixer.conv1d.weight',\n", - " tensor([[[ 0.2776, 0.2169, -0.2840, 0.1736]],\n", - " \n", - " [[-0.0598, -0.2654, 0.2423, -0.0874]],\n", - " \n", - " [[-0.3612, -0.3049, -0.3197, -0.2763]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.1389, 0.2034, -0.1739, 0.1634]],\n", - " \n", - " [[-0.2836, -0.0471, 0.1284, -0.0099]],\n", - " \n", - " [[ 0.2952, -0.2676, -0.3961, 0.2656]]])),\n", - " ('model.layers.18.mixer.conv1d.bias',\n", - " tensor([ 0.1804, 0.0336, 0.4006, ..., 0.2943, -0.1079, 0.0963])),\n", - " ('model.layers.18.mixer.out_proj.weight',\n", - " tensor([[ 0.0109, -0.0181, 0.0148, ..., -0.0105, -0.0011, -0.0052],\n", - " [ 0.0507, 0.0100, -0.0273, ..., -0.0069, 0.0054, 0.0129],\n", - " [ 0.0014, 0.0423, -0.0193, ..., -0.0023, -0.0293, 0.0004],\n", - " ...,\n", - " [ 0.0420, -0.0401, 0.0205, ..., 0.0135, -0.0089, -0.0023],\n", - " [ 0.0242, 0.0273, 0.0139, ..., -0.0402, 0.0061, 0.0119],\n", - " [-0.0145, 0.0102, 0.0245, ..., 0.0205, -0.0251, 0.0006]])),\n", - " ('model.layers.18.mlp.gate_proj.weight',\n", - " tensor([[ 0.0241, -0.0086, 0.0136, ..., -0.0219, -0.0064, -0.0142],\n", - " [-0.0067, 0.0252, 0.0246, ..., -0.0205, -0.0273, 0.0137],\n", - " [-0.0030, 0.0055, -0.0063, ..., 0.0107, 0.0083, -0.0037],\n", - " ...,\n", - " [-0.0154, 0.0101, 0.0221, ..., 0.0025, -0.0109, 0.0133],\n", - " [-0.0175, 0.0105, -0.0246, ..., 0.0244, 0.0023, 0.0080],\n", - " [-0.0060, 0.0183, 0.0297, ..., 0.0420, -0.0006, -0.0119]])),\n", - " ('model.layers.18.mlp.up_proj.weight',\n", - " tensor([[ 0.0066, -0.0009, -0.0070, ..., -0.0064, 0.0002, 0.0196],\n", - " [-0.0173, -0.0362, -0.0011, ..., 0.0158, -0.0198, -0.0046],\n", - " [ 0.0133, -0.0090, -0.0092, ..., 0.0039, -0.0052, -0.0101],\n", - " ...,\n", - " [ 0.0077, -0.0063, 0.0010, ..., 0.0091, 0.0218, 0.0132],\n", - " [ 0.0005, -0.0046, 0.0207, ..., 0.0112, 0.0183, -0.0020],\n", - " [ 0.0238, -0.0022, 0.0364, ..., -0.0042, 0.0237, 0.0183]])),\n", - " ('model.layers.18.mlp.down_proj.weight',\n", - " tensor([[ 0.0305, 0.0178, -0.0264, ..., -0.0158, 0.0135, 0.0132],\n", - " [ 0.0248, -0.0061, 0.0144, ..., -0.0165, 0.0098, 0.0410],\n", - " [-0.0156, -0.0039, 0.0112, ..., -0.0431, -0.0084, -0.0197],\n", - " ...,\n", - " [ 0.0071, 0.0236, -0.0038, ..., 0.0035, -0.0236, 0.0106],\n", - " [-0.0369, -0.0029, -0.0182, ..., -0.0008, -0.0417, 0.0064],\n", - " [-0.0273, 0.0207, 0.0130, ..., 0.0372, 0.0163, 0.0273]])),\n", - " ('model.layers.18.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.18.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.19.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.19.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.19.mixer.in_proj.weight',\n", - " tensor([[-0.0079, 0.0147, -0.0337, ..., -0.0201, -0.0254, 0.0035],\n", - " [ 0.0139, 0.0054, -0.0093, ..., -0.0208, -0.0289, -0.0087],\n", - " [ 0.0004, -0.0034, 0.0090, ..., -0.0109, -0.0093, 0.0102],\n", - " ...,\n", - " [ 0.0128, 0.0015, -0.0101, ..., -0.0482, -0.0217, 0.0144],\n", - " [-0.0100, -0.0079, 0.0286, ..., -0.0025, -0.0210, 0.0164],\n", - " [-0.0264, 0.0015, 0.0031, ..., 0.0027, 0.0131, -0.0384]])),\n", - " ('model.layers.19.mixer.conv1d.weight',\n", - " tensor([[[ 0.4729, 0.3708, -0.4394, -0.3549]],\n", - " \n", - " [[ 0.2230, -0.3271, 0.3017, -0.2552]],\n", - " \n", - " [[-0.0417, 0.1893, 0.4552, -0.0644]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.2565, 0.0407, 0.3521, 0.4116]],\n", - " \n", - " [[ 0.0795, -0.0374, 0.1034, 0.4254]],\n", - " \n", - " [[ 0.3333, 0.2431, 0.3459, -0.2676]]])),\n", - " ('model.layers.19.mixer.conv1d.bias',\n", - " tensor([-0.2287, -0.4446, -0.2300, ..., -0.2317, -0.3395, 0.4310])),\n", - " ('model.layers.19.mixer.out_proj.weight',\n", - " tensor([[-0.0456, -0.0167, -0.0117, ..., -0.0068, -0.0150, 0.0125],\n", - " [ 0.0194, 0.0172, -0.0232, ..., -0.0202, -0.0066, 0.0083],\n", - " [ 0.0320, -0.0065, 0.0274, ..., 0.0200, 0.0090, 0.0105],\n", - " ...,\n", - " [ 0.0315, 0.0415, 0.0128, ..., -0.0143, -0.0338, -0.0231],\n", - " [ 0.0227, -0.0177, -0.0034, ..., 0.0174, 0.0006, 0.0212],\n", - " [ 0.0358, 0.0084, 0.0075, ..., 0.0091, 0.0062, 0.0114]])),\n", - " ('model.layers.19.mlp.gate_proj.weight',\n", - " tensor([[-0.0010, 0.0156, 0.0042, ..., -0.0181, 0.0113, 0.0089],\n", - " [-0.0182, 0.0068, -0.0043, ..., -0.0323, -0.0019, -0.0045],\n", - " [ 0.0168, -0.0093, -0.0162, ..., -0.0074, 0.0166, -0.0334],\n", - " ...,\n", - " [ 0.0038, -0.0211, -0.0054, ..., -0.0229, 0.0193, -0.0210],\n", - " [ 0.0153, -0.0372, 0.0119, ..., 0.0043, -0.0097, -0.0025],\n", - " [ 0.0037, 0.0208, -0.0135, ..., 0.0052, -0.0125, -0.0282]])),\n", - " ('model.layers.19.mlp.up_proj.weight',\n", - " tensor([[-0.0026, 0.0360, 0.0161, ..., 0.0199, -0.0283, -0.0026],\n", - " [ 0.0185, 0.0122, -0.0299, ..., 0.0125, 0.0063, 0.0387],\n", - " [-0.0085, -0.0010, -0.0054, ..., -0.0088, -0.0034, -0.0179],\n", - " ...,\n", - " [-0.0179, 0.0211, -0.0003, ..., -0.0071, -0.0145, 0.0235],\n", - " [-0.0002, 0.0060, -0.0172, ..., -0.0086, 0.0175, -0.0232],\n", - " [-0.0081, -0.0280, -0.0152, ..., -0.0221, 0.0047, -0.0077]])),\n", - " ('model.layers.19.mlp.down_proj.weight',\n", - " tensor([[ 0.0038, -0.0027, -0.0122, ..., 0.0090, 0.0044, 0.0128],\n", - " [ 0.0054, 0.0075, 0.0116, ..., 0.0232, 0.0130, 0.0298],\n", - " [-0.0498, -0.0208, -0.0127, ..., 0.0166, -0.0221, 0.0038],\n", - " ...,\n", - " [ 0.0101, 0.0051, 0.0209, ..., 0.0137, -0.0225, 0.0142],\n", - " [-0.0433, -0.0217, -0.0167, ..., -0.0179, -0.0191, -0.0021],\n", - " [-0.0020, 0.0084, -0.0114, ..., 0.0324, 0.0216, -0.0062]])),\n", - " ('model.layers.19.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.19.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.20.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.20.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.20.mixer.in_proj.weight',\n", - " tensor([[ 3.3776e-02, 3.6619e-02, 6.8532e-03, ..., 5.7664e-02,\n", - " -2.3083e-02, -6.2962e-02],\n", - " [-2.9787e-03, -2.5050e-03, -3.4841e-03, ..., 5.4946e-03,\n", - " 9.0683e-03, 2.1583e-04],\n", - " [ 7.4430e-03, -1.0495e-02, 3.5169e-02, ..., -5.1808e-02,\n", - " 3.2650e-03, -3.1967e-02],\n", - " ...,\n", - " [-5.8685e-02, 4.8452e-02, -1.2612e-02, ..., 1.2174e-02,\n", - " 1.0566e-02, -4.9561e-03],\n", - " [ 3.1722e-03, -2.9390e-03, 1.4502e-05, ..., -2.3297e-02,\n", - " -7.5403e-03, -1.3599e-02],\n", - " [ 1.4845e-02, -4.3150e-02, -1.0338e-02, ..., -1.1149e-02,\n", - " -3.3432e-02, 3.8337e-03]])),\n", - " ('model.layers.20.mixer.conv1d.weight',\n", - " tensor([[[-0.3842, 0.2397, 0.4873, -0.3091]],\n", - " \n", - " [[-0.1886, 0.0751, 0.2026, -0.2674]],\n", - " \n", - " [[-0.0594, 0.3119, -0.2404, 0.1652]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0028, 0.1315, 0.0515, 0.3189]],\n", - " \n", - " [[-0.1461, -0.0457, -0.0536, -0.2306]],\n", - " \n", - " [[-0.3025, -0.3339, 0.3007, -0.3007]]])),\n", - " ('model.layers.20.mixer.conv1d.bias',\n", - " tensor([-0.4901, -0.3784, -0.0173, ..., -0.3946, -0.0728, 0.2187])),\n", - " ('model.layers.20.mixer.out_proj.weight',\n", - " tensor([[ 0.0095, -0.0037, -0.0218, ..., 0.0080, 0.0062, 0.0246],\n", - " [-0.0197, 0.0037, 0.0076, ..., 0.0171, 0.0238, -0.0195],\n", - " [ 0.0364, -0.0165, 0.0224, ..., -0.0099, 0.0007, 0.0340],\n", - " ...,\n", - " [ 0.0235, -0.0072, -0.0319, ..., 0.0045, -0.0196, 0.0011],\n", - " [-0.0369, 0.0083, 0.0021, ..., -0.0357, -0.0039, -0.0150],\n", - " [-0.0174, -0.0211, 0.0111, ..., 0.0251, 0.0040, -0.0308]])),\n", - " ('model.layers.20.mlp.gate_proj.weight',\n", - " tensor([[ 0.0161, -0.0019, -0.0473, ..., 0.0019, 0.0075, -0.0038],\n", - " [-0.0321, -0.0020, -0.0100, ..., 0.0035, 0.0291, -0.0058],\n", - " [-0.0158, 0.0020, 0.0353, ..., 0.0125, 0.0228, -0.0392],\n", - " ...,\n", - " [ 0.0113, 0.0171, 0.0235, ..., 0.0043, 0.0378, 0.0391],\n", - " [ 0.0090, 0.0067, 0.0031, ..., 0.0291, -0.0052, -0.0216],\n", - " [ 0.0042, -0.0112, -0.0161, ..., -0.0063, -0.0156, 0.0211]])),\n", - " ('model.layers.20.mlp.up_proj.weight',\n", - " tensor([[ 0.0104, -0.0302, -0.0220, ..., -0.0072, -0.0083, -0.0066],\n", - " [ 0.0409, -0.0116, -0.0125, ..., 0.0182, 0.0267, 0.0099],\n", - " [-0.0055, 0.0104, 0.0027, ..., -0.0075, -0.0368, -0.0092],\n", - " ...,\n", - " [-0.0089, 0.0243, -0.0028, ..., -0.0136, -0.0176, -0.0054],\n", - " [ 0.0088, 0.0365, -0.0354, ..., 0.0035, 0.0280, 0.0155],\n", - " [-0.0472, 0.0088, 0.0102, ..., -0.0120, 0.0004, -0.0011]])),\n", - " ('model.layers.20.mlp.down_proj.weight',\n", - " tensor([[-0.0089, -0.0112, -0.0007, ..., 0.0360, -0.0077, 0.0261],\n", - " [ 0.0080, -0.0128, -0.0445, ..., 0.0095, -0.0298, 0.0176],\n", - " [ 0.0357, -0.0262, 0.0028, ..., 0.0162, 0.0089, 0.0050],\n", - " ...,\n", - " [-0.0129, 0.0216, 0.0125, ..., -0.0062, -0.0344, -0.0218],\n", - " [ 0.0006, -0.0143, -0.0099, ..., -0.0359, 0.0268, 0.0259],\n", - " [ 0.0222, -0.0154, 0.0013, ..., 0.0108, -0.0077, 0.0186]])),\n", - " ('model.layers.20.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.20.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.21.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.21.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.21.mixer.in_proj.weight',\n", - " tensor([[-0.0300, 0.0058, -0.0107, ..., -0.0318, 0.0350, 0.0350],\n", - " [ 0.0186, 0.0238, -0.0268, ..., 0.0142, -0.0277, -0.0095],\n", - " [-0.0061, 0.0083, 0.0072, ..., 0.0161, 0.0027, -0.0051],\n", - " ...,\n", - " [-0.0358, 0.0330, 0.0151, ..., -0.0376, 0.0057, 0.0174],\n", - " [-0.0021, 0.0068, 0.0151, ..., 0.0077, -0.0353, 0.0095],\n", - " [-0.0113, -0.0043, 0.0064, ..., -0.0063, -0.0232, -0.0058]])),\n", - " ('model.layers.21.mixer.conv1d.weight',\n", - " tensor([[[ 0.0354, 0.0496, -0.0106, 0.0084]],\n", - " \n", - " [[ 0.2553, 0.3217, -0.0078, -0.2333]],\n", - " \n", - " [[-0.1390, 0.0323, 0.4914, -0.2047]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.2243, 0.2984, 0.0188, 0.1830]],\n", - " \n", - " [[ 0.0756, 0.1443, -0.4898, -0.2082]],\n", - " \n", - " [[-0.3685, -0.1311, -0.4037, -0.3276]]])),\n", - " ('model.layers.21.mixer.conv1d.bias',\n", - " tensor([-0.2444, -0.1852, 0.2215, ..., 0.4515, 0.2532, -0.2388])),\n", - " ('model.layers.21.mixer.out_proj.weight',\n", - " tensor([[ 0.0232, 0.0328, 0.0026, ..., -0.0575, 0.0157, -0.0072],\n", - " [-0.0226, 0.0058, -0.0346, ..., 0.0092, 0.0078, 0.0108],\n", - " [ 0.0045, 0.0247, 0.0150, ..., -0.0085, 0.0268, 0.0253],\n", - " ...,\n", - " [ 0.0268, 0.0092, 0.0141, ..., 0.0062, 0.0177, -0.0405],\n", - " [ 0.0163, -0.0269, -0.0177, ..., 0.0029, -0.0080, -0.0036],\n", - " [ 0.0064, 0.0126, 0.0126, ..., -0.0400, -0.0015, -0.0088]])),\n", - " ('model.layers.21.mlp.gate_proj.weight',\n", - " tensor([[-3.7050e-02, 4.5834e-02, 1.9280e-02, ..., 1.6761e-02,\n", - " -5.8295e-03, -1.4284e-02],\n", - " [ 3.0156e-02, 3.2832e-02, 1.1083e-02, ..., -5.8261e-03,\n", - " -3.9076e-02, 5.3379e-03],\n", - " [ 1.3118e-03, 3.1510e-02, 1.5472e-02, ..., 1.8213e-02,\n", - " -2.5180e-02, 6.1512e-04],\n", - " ...,\n", - " [ 4.2010e-02, 1.0362e-02, 7.1759e-03, ..., 1.8667e-03,\n", - " -7.2165e-03, 1.6297e-02],\n", - " [ 1.8175e-02, 1.2840e-02, 3.2857e-03, ..., 1.8495e-02,\n", - " -7.7709e-03, 4.3964e-04],\n", - " [-9.2628e-05, 2.1701e-02, 2.1256e-02, ..., 2.5241e-02,\n", - " 5.0683e-02, -2.5481e-02]])),\n", - " ('model.layers.21.mlp.up_proj.weight',\n", - " tensor([[ 0.0228, 0.0082, -0.0083, ..., 0.0288, 0.0211, 0.0085],\n", - " [-0.0155, 0.0179, 0.0111, ..., -0.0218, -0.0162, -0.0052],\n", - " [ 0.0016, 0.0009, 0.0230, ..., -0.0017, 0.0131, 0.0255],\n", - " ...,\n", - " [-0.0098, -0.0098, -0.0188, ..., 0.0063, 0.0082, 0.0052],\n", - " [-0.0028, 0.0249, -0.0153, ..., -0.0208, 0.0130, -0.0093],\n", - " [ 0.0105, -0.0072, -0.0379, ..., 0.0035, 0.0182, 0.0307]])),\n", - " ('model.layers.21.mlp.down_proj.weight',\n", - " tensor([[-0.0445, -0.0116, 0.0058, ..., 0.0081, -0.0099, 0.0094],\n", - " [ 0.0106, -0.0387, 0.0051, ..., 0.0017, 0.0075, 0.0136],\n", - " [ 0.0022, 0.0058, -0.0268, ..., -0.0088, -0.0149, 0.0125],\n", - " ...,\n", - " [-0.0015, -0.0156, -0.0225, ..., 0.0100, -0.0118, -0.0019],\n", - " [-0.0161, -0.0225, -0.0060, ..., 0.0073, -0.0072, 0.0205],\n", - " [-0.0112, 0.0046, -0.0089, ..., -0.0014, -0.0221, 0.0124]])),\n", - " ('model.layers.21.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.21.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.22.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.22.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.22.mixer.in_proj.weight',\n", - " tensor([[-1.1591e-02, -6.0118e-03, -2.2227e-03, ..., -7.1433e-03,\n", - " -1.5757e-02, -1.5315e-03],\n", - " [-7.6057e-03, -4.2199e-02, 1.4478e-02, ..., 5.6496e-02,\n", - " 8.9105e-05, -3.8658e-03],\n", - " [-1.0330e-03, 2.3586e-02, 2.1835e-02, ..., -1.4911e-03,\n", - " -1.6604e-02, -4.5245e-03],\n", - " ...,\n", - " [-6.7261e-03, -6.9826e-03, -9.3003e-03, ..., -4.3939e-02,\n", - " 2.3792e-02, -5.5165e-03],\n", - " [-1.1798e-02, -3.4709e-02, -4.1277e-03, ..., -5.1867e-03,\n", - " 5.2496e-03, -6.0055e-03],\n", - " [ 7.3402e-04, -1.9525e-02, -5.8966e-03, ..., -1.5972e-02,\n", - " -1.5446e-02, -2.7164e-02]])),\n", - " ('model.layers.22.mixer.conv1d.weight',\n", - " tensor([[[-0.3791, 0.0616, 0.0369, 0.1365]],\n", - " \n", - " [[-0.4674, -0.4557, 0.3894, -0.4765]],\n", - " \n", - " [[ 0.3333, 0.2265, 0.1385, -0.1352]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4363, -0.3526, -0.3982, -0.1049]],\n", - " \n", - " [[ 0.4798, -0.3912, 0.4059, -0.1379]],\n", - " \n", - " [[-0.4427, 0.4661, -0.1990, 0.1668]]])),\n", - " ('model.layers.22.mixer.conv1d.bias',\n", - " tensor([-0.1823, -0.4117, 0.4443, ..., -0.0024, 0.2144, -0.4922])),\n", - " ('model.layers.22.mixer.out_proj.weight',\n", - " tensor([[ 0.0138, -0.0169, -0.0349, ..., -0.0045, 0.0023, -0.0389],\n", - " [ 0.0250, 0.0040, -0.0259, ..., 0.0458, 0.0311, -0.0054],\n", - " [-0.0056, 0.0012, -0.0027, ..., 0.0095, -0.0089, -0.0106],\n", - " ...,\n", - " [ 0.0228, -0.0258, 0.0040, ..., 0.0276, -0.0121, -0.0239],\n", - " [ 0.0082, 0.0041, 0.0145, ..., 0.0079, -0.0076, 0.0177],\n", - " [ 0.0310, -0.0092, -0.0174, ..., 0.0179, 0.0231, -0.0035]])),\n", - " ('model.layers.22.mlp.gate_proj.weight',\n", - " tensor([[ 0.0090, -0.0178, -0.0120, ..., -0.0073, -0.0149, 0.0187],\n", - " [ 0.0263, -0.0093, -0.0074, ..., -0.0472, 0.0049, 0.0288],\n", - " [ 0.0159, -0.0083, 0.0291, ..., 0.0089, -0.0076, -0.0167],\n", - " ...,\n", - " [-0.0008, 0.0206, 0.0199, ..., -0.0134, -0.0366, -0.0202],\n", - " [-0.0069, -0.0275, 0.0054, ..., 0.0093, 0.0108, 0.0094],\n", - " [ 0.0198, 0.0033, -0.0118, ..., -0.0262, 0.0241, 0.0084]])),\n", - " ('model.layers.22.mlp.up_proj.weight',\n", - " tensor([[-0.0277, 0.0038, 0.0006, ..., -0.0222, -0.0313, -0.0133],\n", - " [ 0.0132, -0.0373, 0.0109, ..., 0.0359, -0.0116, 0.0099],\n", - " [ 0.0139, -0.0185, 0.0247, ..., 0.0178, 0.0192, 0.0049],\n", - " ...,\n", - " [ 0.0362, 0.0072, -0.0236, ..., -0.0238, 0.0319, -0.0210],\n", - " [ 0.0013, -0.0047, -0.0060, ..., 0.0106, -0.0074, -0.0185],\n", - " [-0.0228, 0.0176, -0.0047, ..., -0.0034, -0.0174, -0.0264]])),\n", - " ('model.layers.22.mlp.down_proj.weight',\n", - " tensor([[ 0.0149, 0.0122, -0.0037, ..., 0.0044, 0.0171, -0.0186],\n", - " [-0.0037, -0.0002, 0.0066, ..., 0.0263, -0.0025, -0.0012],\n", - " [-0.0075, 0.0209, 0.0045, ..., 0.0082, -0.0160, 0.0079],\n", - " ...,\n", - " [ 0.0001, 0.0507, -0.0078, ..., 0.0001, -0.0119, 0.0286],\n", - " [-0.0198, -0.0122, 0.0047, ..., -0.0052, 0.0130, -0.0007],\n", - " [ 0.0241, -0.0002, -0.0147, ..., 0.0219, -0.0020, -0.0071]])),\n", - " ('model.layers.22.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.22.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.23.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.23.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.23.mixer.in_proj.weight',\n", - " tensor([[-0.0017, 0.0027, -0.0150, ..., 0.0392, -0.0079, -0.0367],\n", - " [ 0.0183, 0.0261, -0.0262, ..., -0.0157, 0.0197, 0.0135],\n", - " [-0.0030, 0.0170, 0.0032, ..., 0.0059, 0.0299, 0.0158],\n", - " ...,\n", - " [-0.0149, 0.0218, 0.0072, ..., -0.0302, 0.0035, 0.0153],\n", - " [-0.0135, 0.0425, 0.0331, ..., -0.0119, -0.0364, 0.0365],\n", - " [-0.0215, -0.0242, 0.0271, ..., 0.0500, 0.0293, 0.0100]])),\n", - " ('model.layers.23.mixer.conv1d.weight',\n", - " tensor([[[ 0.2464, 0.3726, 0.2719, 0.3580]],\n", - " \n", - " [[-0.0520, 0.0010, 0.1396, -0.4634]],\n", - " \n", - " [[ 0.1383, 0.4039, -0.3622, 0.1499]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.4094, 0.0541, 0.2240, -0.1545]],\n", - " \n", - " [[-0.4393, 0.1323, 0.1705, -0.1722]],\n", - " \n", - " [[ 0.2166, -0.4335, -0.4088, -0.1159]]])),\n", - " ('model.layers.23.mixer.conv1d.bias',\n", - " tensor([ 0.3175, -0.0325, -0.4654, ..., 0.3869, -0.2534, 0.1588])),\n", - " ('model.layers.23.mixer.out_proj.weight',\n", - " tensor([[-0.0354, -0.0041, 0.0196, ..., -0.0218, -0.0222, 0.0126],\n", - " [-0.0155, -0.0067, -0.0007, ..., 0.0112, -0.0036, -0.0054],\n", - " [ 0.0141, 0.0040, -0.0218, ..., -0.0178, -0.0031, 0.0162],\n", - " ...,\n", - " [ 0.0264, 0.0063, 0.0088, ..., -0.0310, -0.0116, 0.0239],\n", - " [-0.0031, 0.0056, -0.0243, ..., -0.0350, 0.0004, 0.0004],\n", - " [ 0.0229, -0.0201, 0.0124, ..., 0.0313, -0.0412, -0.0033]])),\n", - " ('model.layers.23.mlp.gate_proj.weight',\n", - " tensor([[ 0.0026, -0.0155, 0.0595, ..., 0.0204, 0.0172, 0.0378],\n", - " [-0.0011, -0.0253, 0.0039, ..., 0.0330, -0.0487, -0.0195],\n", - " [ 0.0174, 0.0039, -0.0029, ..., -0.0026, 0.0104, 0.0108],\n", - " ...,\n", - " [-0.0159, 0.0008, 0.0173, ..., -0.0020, 0.0085, -0.0043],\n", - " [ 0.0101, 0.0221, -0.0034, ..., -0.0268, 0.0056, 0.0137],\n", - " [-0.0031, -0.0151, 0.0073, ..., -0.0083, -0.0064, 0.0109]])),\n", - " ('model.layers.23.mlp.up_proj.weight',\n", - " tensor([[ 0.0173, -0.0132, -0.0027, ..., 0.0391, 0.0268, -0.0185],\n", - " [ 0.0221, -0.0110, -0.0108, ..., -0.0302, 0.0170, 0.0139],\n", - " [-0.0047, -0.0373, 0.0056, ..., -0.0389, -0.0175, -0.0410],\n", - " ...,\n", - " [ 0.0003, 0.0153, 0.0160, ..., 0.0002, -0.0136, 0.0417],\n", - " [-0.0059, -0.0150, -0.0111, ..., 0.0163, 0.0171, 0.0267],\n", - " [-0.0123, -0.0032, 0.0193, ..., -0.0051, -0.0051, -0.0089]])),\n", - " ('model.layers.23.mlp.down_proj.weight',\n", - " tensor([[-0.0092, -0.0148, -0.0345, ..., -0.0240, 0.0425, -0.0099],\n", - " [ 0.0458, 0.0156, -0.0067, ..., -0.0283, 0.0401, 0.0074],\n", - " [ 0.0180, -0.0008, 0.0049, ..., -0.0085, -0.0157, 0.0044],\n", - " ...,\n", - " [-0.0207, 0.0074, -0.0176, ..., 0.0038, -0.0238, -0.0026],\n", - " [-0.0201, 0.0078, 0.0243, ..., -0.0031, 0.0080, -0.0176],\n", - " [-0.0034, 0.0191, 0.0391, ..., -0.0114, 0.0133, -0.0261]])),\n", - " ('model.layers.23.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.23.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.24.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.24.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.24.mixer.in_proj.weight',\n", - " tensor([[-0.0184, -0.0299, 0.0165, ..., 0.0035, 0.0417, -0.0170],\n", - " [-0.0346, -0.0226, 0.0064, ..., 0.0072, 0.0457, -0.0148],\n", - " [ 0.0032, -0.0245, -0.0474, ..., -0.0054, -0.0044, 0.0278],\n", - " ...,\n", - " [ 0.0139, 0.0133, -0.0185, ..., 0.0188, 0.0119, -0.0205],\n", - " [ 0.0235, 0.0161, -0.0095, ..., 0.0013, -0.0382, 0.0213],\n", - " [ 0.0031, -0.0394, 0.0275, ..., -0.0068, 0.0024, 0.0179]])),\n", - " ('model.layers.24.mixer.conv1d.weight',\n", - " tensor([[[-0.1857, -0.4692, 0.4791, 0.3706]],\n", - " \n", - " [[ 0.1749, 0.4182, -0.2338, 0.0838]],\n", - " \n", - " [[-0.1204, -0.2985, -0.0470, 0.4674]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.1485, 0.3118, -0.4916, -0.1610]],\n", - " \n", - " [[ 0.0684, -0.2980, 0.4517, -0.3662]],\n", - " \n", - " [[ 0.2353, -0.2156, -0.3332, -0.0665]]])),\n", - " ('model.layers.24.mixer.conv1d.bias',\n", - " tensor([-0.4464, -0.3485, -0.3916, ..., 0.2513, -0.0601, 0.1546])),\n", - " ('model.layers.24.mixer.out_proj.weight',\n", - " tensor([[-0.0023, 0.0087, -0.0280, ..., 0.0338, -0.0095, -0.0237],\n", - " [-0.0086, -0.0084, 0.0180, ..., 0.0350, 0.0463, -0.0270],\n", - " [-0.0093, -0.0009, 0.0236, ..., 0.0158, 0.0246, 0.0068],\n", - " ...,\n", - " [ 0.0526, 0.0009, 0.0039, ..., -0.0206, -0.0538, 0.0287],\n", - " [ 0.0054, -0.0053, -0.0108, ..., 0.0167, -0.0997, 0.0036],\n", - " [ 0.0009, -0.0297, -0.0424, ..., -0.0096, -0.0235, 0.0117]])),\n", - " ('model.layers.24.mlp.gate_proj.weight',\n", - " tensor([[-0.0265, 0.0259, 0.0224, ..., -0.0080, -0.0394, 0.0290],\n", - " [-0.0101, -0.0256, 0.0079, ..., -0.0017, -0.0287, -0.0163],\n", - " [ 0.0079, -0.0021, -0.0299, ..., 0.0076, 0.0063, 0.0082],\n", - " ...,\n", - " [ 0.0061, 0.0121, 0.0275, ..., -0.0162, 0.0025, -0.0075],\n", - " [-0.0039, -0.0217, -0.0428, ..., -0.0253, 0.0231, 0.0095],\n", - " [-0.0187, 0.0077, -0.0442, ..., 0.0358, -0.0084, -0.0132]])),\n", - " ('model.layers.24.mlp.up_proj.weight',\n", - " tensor([[-0.0201, -0.0119, 0.0505, ..., -0.0025, -0.0187, 0.0011],\n", - " [-0.0105, 0.0154, -0.0163, ..., 0.0248, 0.0028, 0.0178],\n", - " [-0.0163, -0.0271, -0.0100, ..., 0.0129, -0.0220, 0.0269],\n", - " ...,\n", - " [ 0.0138, 0.0329, -0.0091, ..., 0.0038, -0.0194, -0.0223],\n", - " [ 0.0469, 0.0291, -0.0027, ..., 0.0231, 0.0261, 0.0151],\n", - " [-0.0093, -0.0098, 0.0013, ..., 0.0078, -0.0145, 0.0268]])),\n", - " ('model.layers.24.mlp.down_proj.weight',\n", - " tensor([[-0.0195, -0.0003, -0.0046, ..., -0.0132, -0.0118, 0.0242],\n", - " [-0.0267, 0.0199, 0.0243, ..., -0.0063, 0.0134, -0.0163],\n", - " [-0.0044, -0.0303, -0.0215, ..., -0.0148, -0.0216, 0.0079],\n", - " ...,\n", - " [ 0.0159, 0.0180, 0.0098, ..., -0.0126, 0.0176, 0.0087],\n", - " [-0.0203, 0.0041, -0.0256, ..., -0.0047, -0.0236, -0.0256],\n", - " [-0.0017, 0.0133, 0.0490, ..., -0.0344, -0.0118, 0.0020]])),\n", - " ('model.layers.24.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.24.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.25.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.25.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.25.mixer.in_proj.weight',\n", - " tensor([[ 0.0064, 0.0039, 0.0014, ..., 0.0130, -0.0169, 0.0010],\n", - " [ 0.0371, 0.0241, 0.0203, ..., 0.0078, 0.0463, 0.0034],\n", - " [ 0.0184, -0.0431, -0.0026, ..., -0.0164, 0.0279, -0.0138],\n", - " ...,\n", - " [ 0.0146, -0.0138, -0.0418, ..., 0.0234, 0.0145, -0.0213],\n", - " [ 0.0124, -0.0298, -0.0164, ..., -0.0169, 0.0026, -0.0180],\n", - " [-0.0250, -0.0008, -0.0133, ..., -0.0131, -0.0064, 0.0071]])),\n", - " ('model.layers.25.mixer.conv1d.weight',\n", - " tensor([[[ 0.0171, -0.3423, -0.1701, 0.4869]],\n", - " \n", - " [[-0.4648, 0.4797, 0.3531, -0.3819]],\n", - " \n", - " [[-0.1660, -0.3489, -0.2488, 0.4428]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3545, -0.1567, -0.2646, 0.3590]],\n", - " \n", - " [[-0.2175, 0.4394, 0.3840, 0.2620]],\n", - " \n", - " [[ 0.1335, -0.3655, 0.3256, -0.1752]]])),\n", - " ('model.layers.25.mixer.conv1d.bias',\n", - " tensor([-0.0935, 0.0170, 0.0779, ..., -0.2362, 0.2879, 0.2390])),\n", - " ('model.layers.25.mixer.out_proj.weight',\n", - " tensor([[ 2.0220e-02, 5.0645e-05, -1.7425e-02, ..., 8.6082e-03,\n", - " -1.8566e-02, 1.3872e-02],\n", - " [ 2.9139e-02, 1.1096e-02, 4.4168e-02, ..., 3.5600e-02,\n", - " 7.3446e-03, -1.6368e-02],\n", - " [-3.2418e-02, 6.9682e-03, 3.1648e-02, ..., 1.4050e-02,\n", - " -1.6554e-02, 7.2751e-03],\n", - " ...,\n", - " [-3.3057e-02, -7.0545e-04, 3.9661e-02, ..., 2.0690e-02,\n", - " -1.0262e-02, -4.9292e-03],\n", - " [ 1.9849e-02, 1.9666e-02, -1.9398e-02, ..., 1.9285e-02,\n", - " 2.2522e-02, -6.0243e-03],\n", - " [ 1.7683e-02, 2.4301e-02, 7.2223e-03, ..., 3.1373e-02,\n", - " -5.7889e-03, 1.1855e-02]])),\n", - " ('model.layers.25.mlp.gate_proj.weight',\n", - " tensor([[-1.6223e-02, 4.5519e-03, -1.9218e-02, ..., 6.3580e-03,\n", - " -1.2723e-02, -9.7756e-03],\n", - " [-7.4200e-03, 1.8729e-02, 2.6924e-03, ..., 8.2305e-03,\n", - " -1.5727e-02, -9.8748e-03],\n", - " [ 3.2143e-02, -6.1559e-02, 1.6362e-02, ..., -3.6189e-04,\n", - " 1.2017e-04, -1.5734e-02],\n", - " ...,\n", - " [-1.4649e-02, -4.7663e-03, -1.9292e-02, ..., -1.9359e-02,\n", - " 1.8795e-02, 1.0221e-02],\n", - " [-2.4459e-02, 1.1684e-02, -2.8023e-02, ..., 8.0104e-03,\n", - " 8.5950e-05, 1.0542e-02],\n", - " [-4.5679e-03, -1.1421e-02, -2.1099e-02, ..., 4.5089e-03,\n", - " -3.0686e-02, -9.6116e-03]])),\n", - " ('model.layers.25.mlp.up_proj.weight',\n", - " tensor([[-0.0204, -0.0013, -0.0264, ..., -0.0081, -0.0027, 0.0215],\n", - " [-0.0161, 0.0051, -0.0111, ..., -0.0244, 0.0043, -0.0043],\n", - " [-0.0511, 0.0006, -0.0249, ..., 0.0069, 0.0615, 0.0123],\n", - " ...,\n", - " [-0.0086, -0.0016, 0.0064, ..., -0.0347, 0.0097, -0.0134],\n", - " [-0.0003, 0.0015, -0.0053, ..., 0.0210, 0.0135, 0.0337],\n", - " [-0.0205, 0.0028, -0.0272, ..., -0.0168, -0.0072, 0.0019]])),\n", - " ('model.layers.25.mlp.down_proj.weight',\n", - " tensor([[ 0.0166, 0.0044, 0.0180, ..., -0.0127, 0.0070, -0.0066],\n", - " [-0.0056, 0.0140, 0.0151, ..., -0.0239, -0.0140, 0.0470],\n", - " [-0.0030, -0.0093, -0.0188, ..., -0.0090, -0.0092, -0.0088],\n", - " ...,\n", - " [ 0.0465, 0.0277, -0.0349, ..., 0.0424, 0.0015, 0.0206],\n", - " [-0.0096, 0.0174, 0.0250, ..., -0.0142, -0.0022, -0.0141],\n", - " [-0.0195, -0.0174, 0.0033, ..., 0.0027, -0.0061, -0.0108]])),\n", - " ('model.layers.25.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.25.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.26.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.26.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.26.mixer.in_proj.weight',\n", - " tensor([[ 0.0112, 0.0060, -0.0038, ..., -0.0164, 0.0111, 0.0105],\n", - " [ 0.0227, -0.0248, 0.0240, ..., 0.0103, -0.0373, -0.0051],\n", - " [-0.0073, 0.0227, -0.0190, ..., 0.0048, -0.0101, -0.0137],\n", - " ...,\n", - " [ 0.0086, -0.0084, 0.0177, ..., -0.0245, 0.0119, 0.0022],\n", - " [-0.0080, -0.0284, 0.0440, ..., 0.0340, -0.0093, 0.0130],\n", - " [-0.0107, 0.0234, -0.0279, ..., 0.0106, -0.0169, -0.0001]])),\n", - " ('model.layers.26.mixer.conv1d.weight',\n", - " tensor([[[ 0.0550, -0.3464, -0.2378, -0.1244]],\n", - " \n", - " [[-0.0925, -0.2497, 0.2629, -0.1821]],\n", - " \n", - " [[-0.4524, 0.3462, -0.4604, -0.2758]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4555, -0.0839, 0.3936, -0.3707]],\n", - " \n", - " [[ 0.3409, -0.4109, 0.0890, -0.3629]],\n", - " \n", - " [[-0.2769, 0.4033, -0.1090, 0.3055]]])),\n", - " ('model.layers.26.mixer.conv1d.bias',\n", - " tensor([-0.2286, -0.2395, -0.2517, ..., 0.0537, 0.0906, 0.4936])),\n", - " ('model.layers.26.mixer.out_proj.weight',\n", - " tensor([[-0.0316, -0.0423, -0.0053, ..., 0.0024, 0.0084, -0.0270],\n", - " [ 0.0458, -0.0243, 0.0060, ..., -0.0007, -0.0161, -0.0232],\n", - " [ 0.0388, -0.0126, 0.0184, ..., -0.0059, 0.0061, 0.0090],\n", - " ...,\n", - " [ 0.0487, 0.0305, -0.0175, ..., -0.0250, -0.0158, -0.0035],\n", - " [-0.0148, -0.0224, 0.0095, ..., -0.0102, -0.0226, 0.0272],\n", - " [-0.0061, 0.0067, 0.0069, ..., 0.0038, -0.0277, -0.0168]])),\n", - " ('model.layers.26.mlp.gate_proj.weight',\n", - " tensor([[-1.9812e-02, 8.3232e-03, 3.0347e-03, ..., 2.1982e-02,\n", - " 1.3550e-02, -1.1203e-02],\n", - " [ 2.2460e-02, 4.9811e-03, -2.2167e-02, ..., 1.3932e-03,\n", - " 5.3891e-03, -2.8310e-02],\n", - " [ 1.1011e-02, -1.2903e-02, -2.8861e-02, ..., 2.6808e-02,\n", - " -2.8479e-03, -1.3105e-02],\n", - " ...,\n", - " [ 1.1078e-03, -1.1789e-02, -4.4165e-02, ..., 8.2950e-03,\n", - " -1.8015e-02, -1.2234e-02],\n", - " [-2.0721e-02, -4.7919e-04, -4.9474e-02, ..., 7.9999e-05,\n", - " 1.7886e-02, -4.4699e-02],\n", - " [ 8.1279e-03, 1.2636e-02, -2.0932e-02, ..., -3.0361e-03,\n", - " 3.3468e-03, 2.7677e-02]])),\n", - " ('model.layers.26.mlp.up_proj.weight',\n", - " tensor([[-0.0301, -0.0025, -0.0147, ..., -0.0186, 0.0058, -0.0057],\n", - " [ 0.0303, -0.0341, 0.0142, ..., -0.0252, -0.0247, 0.0280],\n", - " [ 0.0209, -0.0425, 0.0073, ..., 0.0063, -0.0040, -0.0076],\n", - " ...,\n", - " [-0.0172, -0.0199, 0.0125, ..., 0.0363, 0.0118, -0.0124],\n", - " [-0.0108, 0.0042, -0.0475, ..., 0.0091, -0.0185, 0.0144],\n", - " [-0.0275, -0.0049, 0.0183, ..., -0.0001, -0.0119, -0.0359]])),\n", - " ('model.layers.26.mlp.down_proj.weight',\n", - " tensor([[-0.0197, -0.0082, -0.0224, ..., -0.0469, -0.0076, -0.0375],\n", - " [-0.0070, -0.0071, 0.0190, ..., -0.0125, 0.0068, 0.0166],\n", - " [ 0.0062, -0.0072, 0.0189, ..., -0.0244, -0.0292, -0.0328],\n", - " ...,\n", - " [-0.0054, 0.0219, 0.0058, ..., 0.0118, 0.0136, -0.0221],\n", - " [-0.0133, 0.0299, -0.0182, ..., -0.0496, -0.0202, 0.0196],\n", - " [-0.0131, -0.0237, -0.0473, ..., 0.0066, 0.0119, 0.0100]])),\n", - " ('model.layers.26.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.26.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.27.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.27.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.27.mixer.in_proj.weight',\n", - " tensor([[ 0.0200, -0.0276, -0.0274, ..., 0.0282, 0.0025, 0.0215],\n", - " [ 0.0054, 0.0218, -0.0175, ..., -0.0054, 0.0211, -0.0073],\n", - " [ 0.0100, -0.0023, 0.0162, ..., 0.0008, -0.0193, -0.0050],\n", - " ...,\n", - " [-0.0241, -0.0197, -0.0142, ..., 0.0039, -0.0175, 0.0045],\n", - " [ 0.0214, 0.0137, -0.0155, ..., -0.0212, 0.0089, 0.0165],\n", - " [ 0.0086, 0.0181, 0.0069, ..., -0.0093, -0.0272, 0.0068]])),\n", - " ('model.layers.27.mixer.conv1d.weight',\n", - " tensor([[[ 0.0519, 0.2061, 0.2635, 0.4916]],\n", - " \n", - " [[ 0.3745, -0.0860, -0.2310, -0.4250]],\n", - " \n", - " [[ 0.0565, 0.3699, 0.2812, -0.4201]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.4073, 0.1852, -0.1687, -0.2643]],\n", - " \n", - " [[-0.0865, -0.0894, 0.2650, -0.4522]],\n", - " \n", - " [[-0.0987, 0.0925, -0.2098, 0.0325]]])),\n", - " ('model.layers.27.mixer.conv1d.bias',\n", - " tensor([-0.4788, -0.0231, -0.4210, ..., -0.3143, -0.2893, 0.0570])),\n", - " ('model.layers.27.mixer.out_proj.weight',\n", - " tensor([[-0.0294, -0.0038, -0.0213, ..., -0.0141, 0.0072, -0.0359],\n", - " [ 0.0131, 0.0173, 0.0159, ..., 0.0030, 0.0400, -0.0065],\n", - " [-0.0111, 0.0374, 0.0109, ..., -0.0338, 0.0312, 0.0073],\n", - " ...,\n", - " [-0.0004, 0.0282, 0.0148, ..., 0.0165, 0.0062, -0.0177],\n", - " [ 0.0265, -0.0331, -0.0056, ..., 0.0407, 0.0154, 0.0176],\n", - " [ 0.0209, -0.0293, 0.0009, ..., -0.0240, -0.0029, -0.0407]])),\n", - " ('model.layers.27.mlp.gate_proj.weight',\n", - " tensor([[-0.0118, 0.0202, -0.0012, ..., 0.0101, 0.0075, 0.0102],\n", - " [ 0.0102, -0.0062, 0.0330, ..., -0.0024, -0.0245, -0.0237],\n", - " [-0.0008, 0.0202, -0.0097, ..., 0.0022, -0.0152, -0.0128],\n", - " ...,\n", - " [-0.0461, 0.0178, 0.0253, ..., 0.0319, 0.0173, -0.0099],\n", - " [ 0.0014, -0.0256, 0.0224, ..., 0.0272, 0.0045, 0.0192],\n", - " [ 0.0146, -0.0357, -0.0089, ..., -0.0147, 0.0383, 0.0354]])),\n", - " ('model.layers.27.mlp.up_proj.weight',\n", - " tensor([[-3.1854e-02, -1.0290e-03, -3.4564e-03, ..., 3.3551e-03,\n", - " 3.2845e-02, 2.1107e-02],\n", - " [-4.8083e-04, -5.8388e-03, 1.7324e-03, ..., 2.0575e-02,\n", - " -1.1685e-02, 1.2504e-02],\n", - " [ 4.6267e-02, -1.8935e-02, -2.4184e-02, ..., -4.8211e-02,\n", - " -3.3912e-04, 3.0527e-02],\n", - " ...,\n", - " [-6.9427e-03, -4.8680e-03, 3.2021e-02, ..., 1.4236e-02,\n", - " 1.9532e-02, 1.3339e-02],\n", - " [ 1.2463e-02, -5.5923e-03, -1.5680e-02, ..., 8.7956e-03,\n", - " 2.8262e-02, -1.2526e-02],\n", - " [-4.8530e-03, -8.8749e-05, 3.3507e-02, ..., -2.8260e-02,\n", - " -2.0571e-03, -8.3943e-03]])),\n", - " ('model.layers.27.mlp.down_proj.weight',\n", - " tensor([[-0.0457, -0.0267, -0.0210, ..., -0.0093, -0.0016, -0.0008],\n", - " [-0.0053, 0.0284, -0.0003, ..., 0.0065, -0.0117, 0.0243],\n", - " [ 0.0120, 0.0023, -0.0180, ..., -0.0003, -0.0313, 0.0163],\n", - " ...,\n", - " [-0.0160, 0.0207, 0.0082, ..., 0.0153, 0.0131, 0.0034],\n", - " [-0.0073, 0.0424, 0.0274, ..., -0.0075, -0.0554, -0.0114],\n", - " [-0.0192, 0.0268, 0.0036, ..., 0.0094, 0.0045, 0.0030]])),\n", - " ('model.layers.27.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.27.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.norm.weight', tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('lm_head.weight',\n", - " tensor([[-0.0141, -0.0445, 0.0071, ..., -0.0143, -0.0239, -0.0512],\n", - " [ 0.0295, -0.0317, -0.0201, ..., -0.0082, 0.0231, -0.0030],\n", - " [-0.0255, -0.0139, 0.0020, ..., -0.0040, -0.0154, 0.0336],\n", - " ...,\n", - " [ 0.0095, 0.0361, 0.0135, ..., -0.0018, 0.0074, -0.0311],\n", - " [-0.0092, 0.0060, 0.0594, ..., -0.0046, 0.0117, 0.0364],\n", - " [ 0.0228, -0.0265, -0.0262, ..., 0.0038, 0.0097, -0.0257]]))])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.state_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "N params SSM: 5.305533088\n" - ] - } - ], - "source": [ - "print(\"N params SSM:\", sum(p.numel() for p in apriel_ssm.parameters() if p.requires_grad)/1e9)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load State dict into SSM" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "apriel_ssm.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.weight', 'model.layers.7.mixer.conv1d.bias', 'model.layers.7.mixer.out_proj.weight', 'model.layers.8.mixer.z_bias', 'model.layers.8.mixer.D', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.out_proj.weight', 'model.layers.9.mixer.z_bias', 'model.layers.9.mixer.D', 'model.layers.9.mixer.in_proj.weight', 'model.layers.9.mixer.conv1d.weight', 'model.layers.9.mixer.conv1d.bias', 'model.layers.9.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.19.mixer.z_bias', 'model.layers.19.mixer.D', 'model.layers.19.mixer.in_proj.weight', 'model.layers.19.mixer.conv1d.weight', 'model.layers.19.mixer.conv1d.bias', 'model.layers.19.mixer.out_proj.weight', 'model.layers.20.mixer.z_bias', 'model.layers.20.mixer.D', 'model.layers.20.mixer.in_proj.weight', 'model.layers.20.mixer.conv1d.weight', 'model.layers.20.mixer.conv1d.bias', 'model.layers.20.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 'model.layers.21.mixer.out_proj.weight', 'model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight', 'model.layers.23.mixer.z_bias', 'model.layers.23.mixer.D', 'model.layers.23.mixer.in_proj.weight', 'model.layers.23.mixer.conv1d.weight', 'model.layers.23.mixer.conv1d.bias', 'model.layers.23.mixer.out_proj.weight', 'model.layers.24.mixer.z_bias', 'model.layers.24.mixer.D', 'model.layers.24.mixer.in_proj.weight', 'model.layers.24.mixer.conv1d.weight', 'model.layers.24.mixer.conv1d.bias', 'model.layers.24.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight', 'model.layers.26.mixer.z_bias', 'model.layers.26.mixer.D', 'model.layers.26.mixer.in_proj.weight', 'model.layers.26.mixer.conv1d.weight', 'model.layers.26.mixer.conv1d.bias', 'model.layers.26.mixer.out_proj.weight', 'model.layers.27.mixer.z_bias', 'model.layers.27.mixer.D', 'model.layers.27.mixer.in_proj.weight', 'model.layers.27.mixer.conv1d.weight', 'model.layers.27.mixer.conv1d.bias', 'model.layers.27.mixer.out_proj.weight'], unexpected_keys=['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight'])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.load_state_dict(apriel_state_dict, strict=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "apriel_ssm.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# apriel_ssm.state_dict()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Save checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'apriel_ssm' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mapriel_ssm\u001b[49m\u001b[38;5;241m.\u001b[39msave_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints/ssm/apriel_ssm_instruct_base\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 2\u001b[0m save_config\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", - "\u001b[0;31mNameError\u001b[0m: name 'apriel_ssm' is not defined" - ] - } - ], - "source": [ - "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_base\",\n", - " save_config=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "24" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.model.layers[0].mixer.n_v_heads" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Try a forward pass" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", - "batch_size = 1\n", - "max_length = 128\n", - "state = SimpleNamespace()\n", - "state.key_value_memory_dict = apriel_ssm.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)\n", - "state.batch_size = batch_size\n", - "state.seqlen_offset = 0\n", - "static_inputs = {\"inference_params\": state,\n", - " \"input_ids\": input_ids,\n", - " \"use_cache\": True,\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-3.0781, 2.3594, 1.4609, ..., -2.3438, -1.9688, 0.6484],\n", - " [-5.8125, 4.9688, 0.4414, ..., -4.2500, -3.5156, -4.8125],\n", - " [-5.5000, 3.3594, 1.1484, ..., -3.4375, -2.3125, -4.4375],\n", - " ...,\n", - " [-2.2812, 0.1465, 2.2344, ..., -7.6875, -3.0312, -6.2500],\n", - " [-6.8750, 1.7812, -1.3750, ..., -7.4688, -5.6875, -4.4062],\n", - " [-2.0156, 2.0938, 3.1094, ..., -3.0156, -2.1406, -2.2812]]],\n", - " device='cuda:0', grad_fn=), all_hidden_states=(), last_hidden_state=tensor([[[-1.3828, 0.0625, -2.7500, ..., -0.6523, -0.8906, 1.4609],\n", - " [ 2.1406, -0.0247, -3.0156, ..., -0.0074, 1.0234, 1.3828],\n", - " [ 1.6016, -0.7266, -1.2422, ..., -0.4004, -0.8242, -0.5586],\n", - " ...,\n", - " [ 1.5234, -0.0262, -1.5469, ..., -0.4922, -1.0078, 1.2344],\n", - " [-0.4629, -0.6055, -1.3906, ..., -0.9922, -0.3066, 1.1875],\n", - " [-0.7539, -0.0243, -2.4688, ..., -1.0625, -2.7188, 2.6875]]],\n", - " device='cuda:0', dtype=torch.bfloat16, grad_fn=))" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.forward(**static_inputs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "import enum" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "class SSMBlockType(str, enum.Enum):\n", - " \"\"\"\n", - " An enum for the available mamba types for the MLP layer.\n", - " \"\"\"\n", - "\n", - " mamba = \"m\"\n", - " mamba2_discrete = \"m2d\"\n", - " mamba2 = \"m2\"\n", - " transformer = \"t\"" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_values([, , , ])" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'m' in SSMBlockType.__members__.values()" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "ename": "KeyError", - "evalue": "'m'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[21], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mm\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[43mSSMBlockType\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mm\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241m.\u001b[39mname\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/enum.py:808\u001b[0m, in \u001b[0;36mEnumType.__getitem__\u001b[0;34m(cls, name)\u001b[0m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mcls\u001b[39m, name):\n\u001b[1;32m 805\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 806\u001b[0m \u001b[38;5;124;03m Return the member matching `name`.\u001b[39;00m\n\u001b[1;32m 807\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 808\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_member_map_\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\n", - "\u001b[0;31mKeyError\u001b[0m: 'm'" - ] - } - ], - "source": [ - "\"m\" == SSMBlockType[\"m\"].name\n" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'m2d'" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "SSMBlockType.mamba2_discrete.value" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "hymba2", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/fast_llm/models/ssm/external/discrete_mamba2.py b/fast_llm/models/ssm/external/discrete_mamba2.py deleted file mode 100644 index bb8afaa7d..000000000 --- a/fast_llm/models/ssm/external/discrete_mamba2.py +++ /dev/null @@ -1,382 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange, repeat -from mamba_ssm.ops.triton.selective_state_update import selective_state_update -from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined - -from .configuration_mtp_llamba import StateUpdateKernel - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -except ImportError: - causal_conv1d_fn, causal_conv1d_update = None, None - - -class DiscreteMamba2(nn.Module): - """DiscreteMamba2 (taken github.com/goombalab/phi-mamba.git).""" - - def __init__( - self, - d_model, - d_state=64, - n_qk_heads=32, - n_v_heads=32, - d_conv=4, - expand=1, - activation="identity", - bias=False, - conv_bias=True, - chunk_size=128, - layer_idx=None, - device=None, - dtype=None, - verification_mode: StateUpdateKernel = StateUpdateKernel.cs, - **kwargs, - ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. - Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr". - - Other options are all experimental and should not need to be configured. - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = self.expand * self.d_model - self.n_qk_heads = n_qk_heads - self.n_v_heads = n_v_heads - self.headdim = self.d_inner // self.n_v_heads - assert self.n_v_heads == self.d_inner // self.headdim - assert self.d_inner % self.headdim == 0 - assert self.n_v_heads % self.n_qk_heads == 0 - self.activation = activation - self.chunk_size = chunk_size - self.layer_idx = layer_idx - self.bias = bias - self.kwargs = kwargs - self.inference_mode = verification_mode - assert verification_mode in [ - StateUpdateKernel.cs, - StateUpdateKernel.standard, - ], "Only chunk scan and standard selective scan are supported for now" - - # Projections - self.in_proj = nn.Linear( - self.d_model, - 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, - bias=bias, - **factory_kwargs, - ) - self.z_bias = ( - nn.Parameter(torch.zeros(self.d_inner, **factory_kwargs)) if not bias else 0 - ) # make sure z_bias always exists - - # Convolutional layer - conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state - self.conv_bias = conv_bias - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - **factory_kwargs, - ) - - # Activation after conv - if self.activation == "identity": - self.act = nn.Identity() - elif self.activation in ["silu", "swish"]: - self.act = nn.SiLU() - else: - raise ValueError(f"Unknown activation {self.activation}") - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.n_v_heads, **factory_kwargs)) - - # out_proj - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - @property - def d_output(self): - """Returns the output dimension of the model.""" - return self.d_model - - @property - def state_to_tensor(self): - """Returns the state of the model as a tensor.""" - return self.layer.state_to_tensor - - def forward(self, u, inference_params=None, **kwargs): - """ - Args: - u: (B, L, D), - inference_params: dict.. Here we assume it contains a mask tensor of shape (B, L) with 1s for valid tokens and 0s for no-op tokens. - - Returns: - outputs: dict. - outputs["hidden_states"]: (B, L, D). - outputs["state"]: inference cache. - """ - outputs = {} - # assert state is None - batch, seqlen, dim = u.shape - - state = None - if inference_params is not None: - state = self._get_states_from_cache(inference_params, batch) - - if ( - state is not None - and inference_params.seqlen_offset > 0 # meaning we are in the middle of the sequence - and seqlen == 1 - and self.inference_mode != StateUpdateKernel.cs - ): - # we go in here for standard 1 token per time-step inference. - # seqlen_offset > 0 means we are in the middle of a sequence - # States are updated inplace - u = u.squeeze(1) if len(u.shape) == 3 else u - out, _ = self.step(u, state) - out = out.unsqueeze(1) if len(u.shape) == 2 else out - return {"hidden_states": out} - - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen - - # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = F.pad(u, (0, 0, 0, padded_len - seqlen)) - - # Project input - xBCzA_log = self.in_proj(u) - - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - if state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) - - # Convolutional layer - xBC = self.convolutional_forward( - xBC, padded_len, mask=inference_params.mask if inference_params is not None else None - ) - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) - - # SSM forward - # TODO: this kernel needs to be aupdated to use the mask! If used solely for throughout benchmarking, it is enough to call it as is. - result = mamba_chunk_scan_combined( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=A_log, - dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), - B=B, - C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), - ) - - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) - else: - y = result - - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = rearrange(y + Du, "b l h p -> b l (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :] - - return outputs - - def step(self, u, state, **kwargs): - """ - Args: - u: (B, D), - state: dict. - - Returns: - out: (B, D), - state: dict. - - """ - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - xBC, conv_state = self.convolutional_step(xBC, state["conv"]) - state["conv"].copy_(conv_state) # update state in place - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) - B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) - C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) - - state["ssm"] = state["ssm"].to(x.dtype) - zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) - ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) - y = selective_state_update( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=repeat(A_log, "b h -> b h p", p=self.headdim), - dt_softplus=True, - A=-ones, - B=B, - C=C, - state=state["ssm"], # will be updated in place - dt_bias=zeros, - D=zeros, - ) - - y = y + self.D[:, None] * x - y = rearrange(y, "b h p -> b (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - - return out, state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - """Allocate memory for inference cache.""" - device = self.in_proj.weight.device - # conv_state: - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, - self.d_conv, - self.conv1d.weight.shape[0], - device=device, - dtype=conv_dtype, - ).transpose(1, 2) - # ssm_state: - ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype - ssm_state = torch.zeros( - batch_size, - self.n_v_heads, - self.headdim, - self.d_state, - device=device, - dtype=ssm_dtype, - ) - return {"conv": conv_state, "ssm": ssm_state} - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - """ - Get states from cache. - - conv_state: (batch, d_conv, conv1d.weight.shape[0]) - ssm_state: (batch, n_qk_heads, headdim, d_state) - """ - assert self.layer_idx is not None - # Allocate memory if not exists - if self.layer_idx not in inference_params.key_value_memory_dict: - inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( - batch_size, inference_params.max_seqlen, dtype=torch.float32 - ) - # Get states - states = inference_params.key_value_memory_dict[self.layer_idx] - if initialize_states: - states["conv"].zero_() - states["ssm"].zero_() - return states - - def convolutional_forward(self, xBC, padded_len, mask=None): - """Convolutional layer forward pass for the full sequence.""" - seqlen = xBC.shape[1] - mask_seql = -1 if mask is None else mask.shape[1] - # If seqlen != mask_seql, this likely means we preallocated mask for static generation, - # but here we are in the prefill phase. - # Note, mask is needed to prevent state upodate for no-op tokens as described in https://proceedings.mlr.press/v262/wu24a.html - # Note, if we want to use joint attanimnet and advancement in selective-scan mode, we would need to implement masking into the kernel of causal_conv1d_fn and mamba_chunk_scan_combined - if causal_conv1d_fn is None or self.activation not in [ - "silu", - "swish", - "identity", - ]: - if mask_seql == seqlen: - xBC = xBC * mask.unsqueeze(-1) - - xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) - if mask_seql == seqlen: - xBC = xBC * mask.unsqueeze(-1) - else: - # TODO: note, this only works for chunked inference, for autoregressive mode we need to update the kernel to make sure conv state is not poluted - if mask_seql == seqlen: - xBC = xBC * mask.unsqueeze(-1) - xBC = causal_conv1d_fn( - xBC.transpose(1, 2), - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - activation=None if self.activation == "identity" else self.activation, - ).transpose(1, 2) - - if mask_seql == seqlen: - xBC = xBC * mask.unsqueeze(-1) - return xBC - - def convolutional_step(self, xBC, conv_state): - """Convolutional layer forward pass for a single step.""" - conv_state = conv_state.to(xBC.dtype) - if causal_conv1d_update: - xBC = causal_conv1d_update( - xBC, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation if self.activation != "identity" else None, - ) - else: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = xBC - xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv_bias: - xBC = xBC + self.conv1d.bias - xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype - - return xBC, conv_state diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py deleted file mode 100644 index 94537c331..000000000 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import Optional, Union - -import lm_eval.models.utils -import torch -from lm_eval.api.registry import register_model -from lm_eval.models.huggingface import HFLM - - -@register_model("apriel_ssm") -class AprielSSMWrapper(HFLM): - """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" - - def __init__(self, pretrained, **kwargs) -> None: - if "backend" in kwargs: - # rene currently only supports causal models - assert kwargs["backend"] == "causal" - - super().__init__( - pretrained=pretrained, - backend=kwargs.pop("backend", "causal"), - tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), - max_length=kwargs.pop("max_length", 4096), - **kwargs, - ) - - def _get_config(self, pretrained: str, **kwargs) -> None: - """Get the model configuration.""" - from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig - - self._config = AprielSSMConfig.from_pretrained(pretrained) - - def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: - """Create the model.""" - from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM - - self._model = AprielSSMForCausalLM.from_pretrained( - pretrained, - device=self._device, - dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), - trust_remote_code=True, - ) - - def _model_generate(self, context, max_length, stop, **generation_kwargs): - """Generate text from the model.""" - for key in ("do_sample", "attention_mask"): - if key in generation_kwargs: - generation_kwargs.pop(key) - - # The custom GenerationMixin imported from mamba_ssm currently does not support - # passing stopping criteria. - # For the time being, we simply generate to max length, then truncate (equivalent result). - # This should be revisited to speed up generation - # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) - - return self.model.generate( - input_ids=context, - max_length=max_length, - **generation_kwargs, - ) diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm/models/ssm/external/eval/run_lm_eval.py deleted file mode 100644 index af07869a8..000000000 --- a/fast_llm/models/ssm/external/eval/run_lm_eval.py +++ /dev/null @@ -1,6 +0,0 @@ -from lm_eval.__main__ import cli_evaluate - -from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import AprielSSMWrapper # noqa: F401 - -if __name__ == "__main__": - cli_evaluate() From 6532c5f94c3aea38018b2a06e413a59ad98fb4aa Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 12:23:57 +0000 Subject: [PATCH 051/122] hybrid config --- .../models/ssm/external/ariel_to_ssm.ipynb | 3526 +++++++++++++++++ .../configuration_ssm_hybrid_apriel.py | 446 +++ .../external/modeling_ssm_hybrid_apriel.py | 1203 ++++++ 3 files changed, 5175 insertions(+) create mode 100644 fast_llm/models/ssm/external/ariel_to_ssm.ipynb create mode 100644 fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py create mode 100644 fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb new file mode 100644 index 000000000..496338cb0 --- /dev/null +++ b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb @@ -0,0 +1,3526 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import torch\n", + "from mamba_ssm import MambaLMHeadModel\n", + "from mamba_ssm.models.config_mamba import MambaConfig\n", + "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", + "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", + "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", + "from transformers.cache_utils import StaticCache\n", + "from types import SimpleNamespace\n", + "\n", + "# make sure the code changes reflected without reload\n", + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 8.90it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "AprielForCausalLM(\n", + " (model): AprielModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()\n", + "apriel_model.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.bfloat16" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_model.config.torch_dtype" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "n_params = sum(p.numel() for p in apriel_model.parameters() if p.requires_grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4.83207168" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n_params/1e9" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n" + ] + } + ], + "source": [ + "config_apriel = AprielSSMConfig.from_pretrained(\"/mnt/checkpoints_fml/pretrained_models/ssm/apriel_ssm_instruct_base\", trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", + "You are using a model of type llamba to instantiate a model of type apriel_ssm. This is not supported for all configurations of models and can yield errors.\n" + ] + }, + { + "ename": "KeyError", + "evalue": "'n_qk_heads'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m stage2_checkpoint \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m stage2_apriel_ssm \u001b[38;5;241m=\u001b[39m \u001b[43mAprielSSMForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstage2_checkpoint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbfloat16\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3571\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3569\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m 3570\u001b[0m config_path \u001b[38;5;241m=\u001b[39m config \u001b[38;5;28;01mif\u001b[39;00m config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> 3571\u001b[0m config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3572\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3573\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3574\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_unused_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 3575\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3576\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3577\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3578\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3579\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3580\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3581\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3582\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_auto\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_auto_class\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3583\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_pipeline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_pipeline\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3584\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3585\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3587\u001b[0m \u001b[38;5;66;03m# In case one passes a config to `from_pretrained` + \"attn_implementation\"\u001b[39;00m\n\u001b[1;32m 3588\u001b[0m \u001b[38;5;66;03m# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3592\u001b[0m \u001b[38;5;66;03m# we pop attn_implementation from the kwargs but this handles the case where users\u001b[39;00m\n\u001b[1;32m 3593\u001b[0m \u001b[38;5;66;03m# passes manually the config to `from_pretrained`.\u001b[39;00m\n\u001b[1;32m 3594\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:569\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type:\n\u001b[1;32m 564\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 565\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are using a model of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to instantiate a model of type \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 566\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 567\u001b[0m )\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:740\u001b[0m, in \u001b[0;36mPretrainedConfig.from_dict\u001b[0;34m(cls, config_dict, **kwargs)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;66;03m# We remove it from kwargs so that it does not appear in `return_unused_kwargs`.\u001b[39;00m\n\u001b[1;32m 738\u001b[0m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 740\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(config, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpruned_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 743\u001b[0m config\u001b[38;5;241m.\u001b[39mpruned_heads \u001b[38;5;241m=\u001b[39m {\u001b[38;5;28mint\u001b[39m(key): value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m config\u001b[38;5;241m.\u001b[39mpruned_heads\u001b[38;5;241m.\u001b[39mitems()}\n", + "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/configuration_ssm_apriel.py:99\u001b[0m, in \u001b[0;36mAprielSSMConfig.__init__\u001b[0;34m(self, vocab_size, hidden_size, intermediate_size, num_hidden_layers, hidden_act, initializer_range, use_cache, pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings, mlp_bias, rms_norm_eps, ssm_cfg, head_dim, **kwargs)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 82\u001b[0m pad_token_id\u001b[38;5;241m=\u001b[39mpad_token_id,\n\u001b[1;32m 83\u001b[0m bos_token_id\u001b[38;5;241m=\u001b[39mbos_token_id,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 87\u001b[0m )\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg \u001b[38;5;241m=\u001b[39m ssm_cfg \u001b[38;5;129;01mor\u001b[39;00m {\n\u001b[1;32m 90\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_state\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 91\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_v_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim, \u001b[38;5;66;03m# num_heads * head_dim\u001b[39;00m\n\u001b[1;32m 98\u001b[0m }\n\u001b[0;32m---> 99\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mssm_cfg\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mn_qk_heads\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n", + "\u001b[0;31mKeyError\u001b[0m: 'n_qk_heads'" + ] + } + ], + "source": [ + "stage2_checkpoint = \"/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\"\n", + "stage2_apriel_ssm = AprielSSMForCausalLM.from_pretrained(stage2_checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "apriel_ssm_config = AprielSSMConfig(vocab_size=config.vocab_size, \n", + " hidden_size=config.hidden_size,\n", + " intermediate_size=config.intermediate_size,\n", + " num_hidden_layers=config.num_hidden_layers,\n", + " hidden_act=config.hidden_act,\n", + " initializer_range=config.initializer_range,\n", + " use_cache=config.use_cache,\n", + " mlp_bias=config.mlp_bias,\n", + " tie_word_embeddings=config.tie_word_embeddings,\n", + " pad_token_id=config.pad_token_id,\n", + " bos_token_id=config.bos_token_id,\n", + " eos_token_id=config.eos_token_id,\n", + " head_dim=config.head_dim,\n", + " rms_norm_eps=config.rms_norm_eps)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "apriel_ssm = AprielSSMForCausalLM(apriel_ssm_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "OrderedDict([('model.embed_tokens.weight',\n", + " tensor([[ 0.0105, 0.0330, -0.0032, ..., 0.0076, -0.0051, 0.0112],\n", + " [-0.0111, -0.0101, 0.0064, ..., 0.0144, 0.0098, -0.0194],\n", + " [ 0.0301, 0.0228, 0.0105, ..., -0.0159, 0.0112, -0.0009],\n", + " ...,\n", + " [ 0.0266, 0.0224, -0.0150, ..., 0.0189, -0.0253, -0.0300],\n", + " [-0.0304, 0.0249, 0.0140, ..., -0.0235, 0.0315, -0.0188],\n", + " [-0.0215, -0.0034, 0.0035, ..., -0.0125, 0.0084, 0.0246]])),\n", + " ('model.layers.0.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.0.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.0.mixer.in_proj.weight',\n", + " tensor([[ 0.0104, 0.0055, -0.0148, ..., 0.0208, -0.0074, 0.0015],\n", + " [ 0.0102, 0.0148, 0.0148, ..., -0.0041, 0.0224, -0.0336],\n", + " [ 0.0129, -0.0179, -0.0120, ..., 0.0175, 0.0300, -0.0234],\n", + " ...,\n", + " [-0.0215, 0.0002, 0.0093, ..., -0.0424, 0.0016, -0.0162],\n", + " [-0.0178, -0.0093, 0.0226, ..., 0.0005, 0.0062, 0.0150],\n", + " [-0.0204, 0.0039, -0.0364, ..., -0.0128, 0.0002, 0.0134]])),\n", + " ('model.layers.0.mixer.conv1d.weight',\n", + " tensor([[[-0.1064, -0.3782, -0.3080, -0.3179]],\n", + " \n", + " [[-0.3493, 0.2230, 0.1062, 0.0614]],\n", + " \n", + " [[-0.4650, 0.0300, 0.3021, 0.1197]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.3686, 0.0679, 0.1440, 0.4445]],\n", + " \n", + " [[-0.1480, 0.3750, -0.0552, -0.0297]],\n", + " \n", + " [[ 0.0677, 0.0925, -0.0268, -0.0232]]])),\n", + " ('model.layers.0.mixer.conv1d.bias',\n", + " tensor([ 0.1379, 0.0862, -0.0723, ..., -0.2628, -0.1867, -0.1233])),\n", + " ('model.layers.0.mixer.out_proj.weight',\n", + " tensor([[ 0.0208, -0.0106, -0.0016, ..., 0.0117, 0.0140, -0.0040],\n", + " [-0.0147, 0.0419, 0.0327, ..., -0.0073, -0.0127, 0.0190],\n", + " [-0.0218, 0.0030, 0.0115, ..., -0.0062, 0.0214, 0.0105],\n", + " ...,\n", + " [ 0.0089, 0.0154, -0.0178, ..., -0.0206, -0.0378, 0.0102],\n", + " [ 0.0153, -0.0249, 0.0219, ..., 0.0119, 0.0019, 0.0383],\n", + " [-0.0126, 0.0284, -0.0035, ..., 0.0118, -0.0186, -0.0232]])),\n", + " ('model.layers.0.mlp.gate_proj.weight',\n", + " tensor([[-0.0032, -0.0405, 0.0180, ..., -0.0030, -0.0222, 0.0069],\n", + " [-0.0071, -0.0064, -0.0207, ..., 0.0037, -0.0077, 0.0261],\n", + " [ 0.0236, 0.0167, 0.0065, ..., 0.0064, 0.0035, -0.0092],\n", + " ...,\n", + " [-0.0357, 0.0192, 0.0099, ..., -0.0067, -0.0181, 0.0082],\n", + " [-0.0139, -0.0161, -0.0015, ..., -0.0052, -0.0337, 0.0514],\n", + " [ 0.0105, -0.0205, 0.0198, ..., 0.0090, 0.0315, 0.0066]])),\n", + " ('model.layers.0.mlp.up_proj.weight',\n", + " tensor([[ 0.0074, 0.0237, -0.0300, ..., 0.0343, 0.0016, 0.0395],\n", + " [ 0.0270, 0.0085, 0.0193, ..., 0.0199, -0.0139, 0.0094],\n", + " [ 0.0036, 0.0073, 0.0149, ..., 0.0094, 0.0346, -0.0111],\n", + " ...,\n", + " [ 0.0159, -0.0346, -0.0128, ..., 0.0377, -0.0531, -0.0305],\n", + " [ 0.0283, 0.0162, -0.0377, ..., -0.0254, 0.0110, -0.0167],\n", + " [-0.0277, 0.0130, 0.0161, ..., 0.0089, -0.0190, 0.0214]])),\n", + " ('model.layers.0.mlp.down_proj.weight',\n", + " tensor([[ 0.0157, 0.0105, 0.0036, ..., 0.0229, 0.0080, 0.0303],\n", + " [-0.0143, -0.0067, 0.0016, ..., 0.0494, -0.0043, 0.0072],\n", + " [-0.0148, 0.0113, 0.0025, ..., -0.0186, 0.0206, -0.0119],\n", + " ...,\n", + " [-0.0226, 0.0099, 0.0010, ..., 0.0123, -0.0170, 0.0024],\n", + " [-0.0120, -0.0015, -0.0355, ..., 0.0064, 0.0175, -0.0065],\n", + " [ 0.0364, 0.0364, 0.0265, ..., -0.0222, 0.0030, 0.0296]])),\n", + " ('model.layers.0.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.0.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.1.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.1.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.1.mixer.in_proj.weight',\n", + " tensor([[-0.0116, -0.0182, -0.0017, ..., -0.0216, -0.0136, -0.0203],\n", + " [-0.0142, -0.0106, -0.0334, ..., 0.0287, -0.0273, 0.0050],\n", + " [ 0.0131, -0.0106, -0.0012, ..., 0.0261, -0.0228, -0.0026],\n", + " ...,\n", + " [-0.0029, 0.0023, 0.0360, ..., -0.0195, 0.0018, -0.0227],\n", + " [ 0.0004, 0.0015, -0.0051, ..., -0.0095, 0.0269, 0.0179],\n", + " [ 0.0295, -0.0520, 0.0009, ..., 0.0019, 0.0255, 0.0478]])),\n", + " ('model.layers.1.mixer.conv1d.weight',\n", + " tensor([[[-0.4725, -0.2938, -0.3816, -0.1239]],\n", + " \n", + " [[-0.2002, 0.3790, 0.1908, -0.4679]],\n", + " \n", + " [[-0.3674, 0.3774, -0.2479, 0.4324]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.4181, 0.2263, -0.1937, 0.3585]],\n", + " \n", + " [[ 0.0704, 0.0913, 0.4217, 0.3004]],\n", + " \n", + " [[ 0.3175, -0.3239, -0.0614, -0.3978]]])),\n", + " ('model.layers.1.mixer.conv1d.bias',\n", + " tensor([ 0.4302, 0.0269, -0.3462, ..., 0.4887, 0.2848, 0.0745])),\n", + " ('model.layers.1.mixer.out_proj.weight',\n", + " tensor([[-0.0069, 0.0233, 0.0133, ..., -0.0064, -0.0085, 0.0166],\n", + " [-0.0302, 0.0129, -0.0042, ..., 0.0109, 0.0009, -0.0087],\n", + " [-0.0373, -0.0233, -0.0043, ..., -0.0017, 0.0384, -0.0114],\n", + " ...,\n", + " [-0.0219, 0.0330, -0.0341, ..., 0.0080, 0.0089, 0.0268],\n", + " [-0.0019, -0.0069, 0.0276, ..., 0.0182, -0.0240, 0.0163],\n", + " [ 0.0081, 0.0070, 0.0156, ..., -0.0135, 0.0469, -0.0221]])),\n", + " ('model.layers.1.mlp.gate_proj.weight',\n", + " tensor([[ 0.0175, -0.0074, -0.0028, ..., 0.0197, 0.0034, 0.0221],\n", + " [ 0.0063, 0.0339, -0.0047, ..., 0.0037, -0.0126, -0.0342],\n", + " [-0.0093, -0.0148, -0.0236, ..., 0.0190, -0.0451, -0.0173],\n", + " ...,\n", + " [ 0.0167, 0.0161, 0.0019, ..., -0.0083, -0.0133, 0.0141],\n", + " [-0.0163, 0.0383, -0.0203, ..., 0.0336, -0.0148, 0.0013],\n", + " [-0.0138, -0.0275, -0.0268, ..., -0.0243, -0.0031, -0.0227]])),\n", + " ('model.layers.1.mlp.up_proj.weight',\n", + " tensor([[ 0.0054, 0.0031, 0.0256, ..., 0.0002, 0.0020, -0.0050],\n", + " [ 0.0247, -0.0298, -0.0218, ..., -0.0161, 0.0253, 0.0128],\n", + " [-0.0231, -0.0012, 0.0130, ..., 0.0031, -0.0324, 0.0107],\n", + " ...,\n", + " [ 0.0359, -0.0202, 0.0386, ..., -0.0104, 0.0274, 0.0161],\n", + " [ 0.0062, -0.0111, 0.0338, ..., 0.0041, 0.0001, -0.0019],\n", + " [ 0.0105, -0.0258, 0.0184, ..., -0.0270, -0.0138, -0.0367]])),\n", + " ('model.layers.1.mlp.down_proj.weight',\n", + " tensor([[-0.0163, -0.0308, -0.0203, ..., 0.0002, -0.0227, 0.0019],\n", + " [ 0.0206, 0.0037, 0.0064, ..., -0.0261, -0.0206, 0.0063],\n", + " [ 0.0044, -0.0073, -0.0576, ..., -0.0015, -0.0082, 0.0022],\n", + " ...,\n", + " [-0.0034, 0.0142, -0.0547, ..., -0.0106, -0.0090, 0.0249],\n", + " [-0.0068, 0.0127, -0.0066, ..., -0.0255, 0.0004, 0.0106],\n", + " [-0.0293, 0.0146, -0.0142, ..., -0.0073, -0.0284, -0.0069]])),\n", + " ('model.layers.1.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.1.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.2.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.2.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.2.mixer.in_proj.weight',\n", + " tensor([[ 0.0337, -0.0055, -0.0538, ..., -0.0051, 0.0107, -0.0338],\n", + " [ 0.0227, -0.0008, 0.0003, ..., -0.0312, 0.0090, -0.0126],\n", + " [-0.0238, 0.0146, 0.0240, ..., -0.0114, -0.0180, 0.0025],\n", + " ...,\n", + " [-0.0208, -0.0261, 0.0227, ..., 0.0071, 0.0014, 0.0237],\n", + " [ 0.0356, 0.0372, 0.0186, ..., 0.0052, 0.0049, -0.0195],\n", + " [ 0.0023, -0.0159, -0.0238, ..., 0.0194, -0.0056, -0.0275]])),\n", + " ('model.layers.2.mixer.conv1d.weight',\n", + " tensor([[[ 0.1054, -0.4185, 0.4229, 0.3289]],\n", + " \n", + " [[-0.0081, 0.0321, 0.1334, -0.1055]],\n", + " \n", + " [[ 0.1587, -0.3806, -0.1336, -0.2662]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.2830, -0.3875, -0.2972, 0.0030]],\n", + " \n", + " [[ 0.4210, 0.2190, -0.4942, 0.0465]],\n", + " \n", + " [[-0.1830, -0.3686, 0.2928, -0.0313]]])),\n", + " ('model.layers.2.mixer.conv1d.bias',\n", + " tensor([-0.2931, -0.3513, -0.3013, ..., -0.1934, -0.3115, 0.3889])),\n", + " ('model.layers.2.mixer.out_proj.weight',\n", + " tensor([[-0.0038, -0.0160, -0.0042, ..., 0.0062, 0.0059, -0.0126],\n", + " [-0.0027, -0.0012, -0.0065, ..., -0.0032, 0.0129, -0.0298],\n", + " [ 0.0394, -0.0096, 0.0107, ..., -0.0290, 0.0248, 0.0308],\n", + " ...,\n", + " [ 0.0087, 0.0067, -0.0261, ..., -0.0038, -0.0168, 0.0485],\n", + " [ 0.0118, 0.0042, -0.0186, ..., 0.0104, 0.0281, 0.0028],\n", + " [ 0.0304, -0.0382, -0.0028, ..., -0.0264, -0.0050, 0.0050]])),\n", + " ('model.layers.2.mlp.gate_proj.weight',\n", + " tensor([[-0.0169, 0.0036, 0.0024, ..., 0.0429, 0.0313, 0.0167],\n", + " [-0.0100, 0.0011, -0.0024, ..., -0.0065, 0.0090, 0.0123],\n", + " [ 0.0102, 0.0282, 0.0166, ..., -0.0082, 0.0123, 0.0253],\n", + " ...,\n", + " [ 0.0168, -0.0056, -0.0096, ..., -0.0090, 0.0150, 0.0209],\n", + " [ 0.0258, 0.0113, -0.0093, ..., 0.0335, 0.0386, -0.0156],\n", + " [ 0.0129, 0.0338, -0.0006, ..., -0.0346, 0.0135, -0.0213]])),\n", + " ('model.layers.2.mlp.up_proj.weight',\n", + " tensor([[-0.0029, 0.0416, -0.0102, ..., -0.0413, 0.0019, 0.0063],\n", + " [ 0.0054, 0.0138, 0.0031, ..., -0.0077, -0.0070, -0.0016],\n", + " [ 0.0128, 0.0153, -0.0147, ..., -0.0131, -0.0244, 0.0097],\n", + " ...,\n", + " [-0.0190, -0.0025, 0.0322, ..., -0.0106, -0.0323, -0.0144],\n", + " [-0.0269, -0.0007, 0.0070, ..., 0.0191, -0.0025, 0.0033],\n", + " [-0.0311, 0.0217, -0.0021, ..., 0.0302, -0.0131, 0.0388]])),\n", + " ('model.layers.2.mlp.down_proj.weight',\n", + " tensor([[ 0.0150, -0.0127, 0.0372, ..., 0.0018, 0.0018, 0.0187],\n", + " [-0.0262, 0.0164, 0.0281, ..., 0.0120, -0.0187, -0.0177],\n", + " [ 0.0129, -0.0042, 0.0018, ..., -0.0136, 0.0278, 0.0284],\n", + " ...,\n", + " [ 0.0048, 0.0421, -0.0018, ..., 0.0002, -0.0064, 0.0085],\n", + " [ 0.0276, 0.0146, 0.0228, ..., 0.0055, -0.0288, -0.0081],\n", + " [-0.0133, 0.0102, 0.0318, ..., 0.0209, -0.0270, 0.0128]])),\n", + " ('model.layers.2.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.2.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.3.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.3.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.3.mixer.in_proj.weight',\n", + " tensor([[ 7.4766e-03, -9.8698e-03, -1.9172e-02, ..., 3.7842e-02,\n", + " -2.1648e-03, 2.8147e-03],\n", + " [ 2.4954e-02, -1.2659e-02, 8.0447e-04, ..., 3.1716e-02,\n", + " 4.9989e-03, 6.4200e-03],\n", + " [-3.3345e-02, -1.5256e-02, 2.7295e-02, ..., -1.1240e-02,\n", + " 9.7000e-03, 3.1136e-05],\n", + " ...,\n", + " [-2.0807e-04, -2.5132e-02, -1.9983e-02, ..., -2.9541e-02,\n", + " 4.6152e-04, 5.5341e-02],\n", + " [ 2.0498e-03, 2.2021e-02, -7.6882e-03, ..., 1.6469e-02,\n", + " -1.0645e-02, -1.8442e-03],\n", + " [ 2.0949e-03, -1.2398e-02, 1.2922e-02, ..., 1.1862e-02,\n", + " -4.7119e-03, 3.2352e-02]])),\n", + " ('model.layers.3.mixer.conv1d.weight',\n", + " tensor([[[ 0.2590, 0.1670, 0.3987, -0.1694]],\n", + " \n", + " [[-0.4425, 0.1468, 0.3060, -0.0764]],\n", + " \n", + " [[-0.3638, -0.0575, 0.2156, -0.2468]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0111, -0.0182, -0.3816, 0.0382]],\n", + " \n", + " [[-0.4723, -0.3712, 0.1963, 0.2877]],\n", + " \n", + " [[-0.4890, 0.1197, 0.1361, 0.3282]]])),\n", + " ('model.layers.3.mixer.conv1d.bias',\n", + " tensor([-0.4712, -0.3272, 0.4587, ..., -0.3145, 0.4086, 0.4005])),\n", + " ('model.layers.3.mixer.out_proj.weight',\n", + " tensor([[-0.0362, 0.0137, -0.0296, ..., -0.0028, 0.0104, 0.0393],\n", + " [ 0.0130, 0.0246, -0.0132, ..., 0.0082, -0.0044, -0.0054],\n", + " [-0.0081, -0.0115, -0.0064, ..., 0.0250, -0.0076, -0.0021],\n", + " ...,\n", + " [ 0.0230, -0.0055, 0.0056, ..., 0.0076, 0.0016, -0.0068],\n", + " [ 0.0472, -0.0068, 0.0336, ..., 0.0079, 0.0211, 0.0031],\n", + " [-0.0450, -0.0005, 0.0219, ..., 0.0044, -0.0006, -0.0278]])),\n", + " ('model.layers.3.mlp.gate_proj.weight',\n", + " tensor([[ 0.0034, 0.0445, -0.0132, ..., 0.0290, 0.0019, 0.0048],\n", + " [ 0.0271, 0.0109, 0.0028, ..., -0.0304, -0.0237, -0.0017],\n", + " [ 0.0098, 0.0252, 0.0392, ..., 0.0486, 0.0326, -0.0171],\n", + " ...,\n", + " [-0.0015, 0.0080, 0.0005, ..., -0.0158, -0.0067, 0.0347],\n", + " [-0.0638, 0.0120, 0.0076, ..., 0.0007, 0.0052, -0.0109],\n", + " [-0.0303, -0.0168, -0.0537, ..., -0.0163, -0.0030, -0.0068]])),\n", + " ('model.layers.3.mlp.up_proj.weight',\n", + " tensor([[-0.0074, -0.0101, 0.0073, ..., -0.0012, -0.0208, -0.0239],\n", + " [ 0.0035, 0.0010, 0.0157, ..., -0.0228, -0.0224, 0.0194],\n", + " [ 0.0457, -0.0129, -0.0063, ..., -0.0312, 0.0261, -0.0018],\n", + " ...,\n", + " [ 0.0012, 0.0093, 0.0121, ..., -0.0035, -0.0367, -0.0454],\n", + " [ 0.0308, -0.0334, 0.0062, ..., 0.0043, -0.0031, -0.0406],\n", + " [-0.0175, -0.0089, -0.0137, ..., -0.0322, -0.0070, -0.0219]])),\n", + " ('model.layers.3.mlp.down_proj.weight',\n", + " tensor([[ 0.0226, 0.0074, -0.0170, ..., 0.0035, 0.0420, -0.0085],\n", + " [ 0.0116, 0.0173, -0.0009, ..., -0.0302, 0.0075, 0.0153],\n", + " [-0.0092, 0.0119, 0.0164, ..., 0.0233, -0.0177, -0.0397],\n", + " ...,\n", + " [-0.0006, -0.0275, 0.0127, ..., -0.0185, 0.0335, -0.0133],\n", + " [ 0.0064, -0.0200, 0.0296, ..., 0.0041, -0.0114, -0.0221],\n", + " [ 0.0317, 0.0392, 0.0553, ..., 0.0191, 0.0188, -0.0176]])),\n", + " ('model.layers.3.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.3.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.4.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.4.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.4.mixer.in_proj.weight',\n", + " tensor([[-0.0266, 0.0092, -0.0260, ..., -0.0121, -0.0286, 0.0267],\n", + " [ 0.0144, -0.0053, -0.0060, ..., -0.0065, 0.0201, -0.0025],\n", + " [-0.0092, -0.0465, -0.0032, ..., 0.0192, -0.0026, 0.0104],\n", + " ...,\n", + " [-0.0210, -0.0286, -0.0148, ..., 0.0593, 0.0130, 0.0118],\n", + " [ 0.0361, -0.0070, 0.0054, ..., -0.0073, 0.0004, 0.0287],\n", + " [ 0.0450, -0.0286, 0.0191, ..., -0.0180, 0.0039, -0.0033]])),\n", + " ('model.layers.4.mixer.conv1d.weight',\n", + " tensor([[[ 0.1450, 0.2065, -0.1750, -0.4560]],\n", + " \n", + " [[-0.2889, -0.4707, -0.0741, 0.1254]],\n", + " \n", + " [[-0.4665, 0.1876, -0.4049, 0.1143]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0709, 0.2021, -0.0053, -0.1558]],\n", + " \n", + " [[-0.0195, -0.4046, -0.2437, -0.4405]],\n", + " \n", + " [[-0.3615, -0.4314, 0.1667, 0.3139]]])),\n", + " ('model.layers.4.mixer.conv1d.bias',\n", + " tensor([-0.3220, -0.4181, -0.0623, ..., 0.2788, 0.0518, 0.4607])),\n", + " ('model.layers.4.mixer.out_proj.weight',\n", + " tensor([[-0.0011, -0.0279, -0.0160, ..., -0.0222, 0.0262, 0.0234],\n", + " [ 0.0024, 0.0178, -0.0142, ..., 0.0048, -0.0145, 0.0332],\n", + " [-0.0084, -0.0037, 0.0054, ..., -0.0201, -0.0341, -0.0053],\n", + " ...,\n", + " [-0.0120, -0.0440, 0.0097, ..., -0.0070, -0.0129, 0.0170],\n", + " [ 0.0096, -0.0034, -0.0025, ..., 0.0242, 0.0047, 0.0093],\n", + " [ 0.0254, 0.0207, 0.0135, ..., 0.0204, -0.0185, -0.0026]])),\n", + " ('model.layers.4.mlp.gate_proj.weight',\n", + " tensor([[ 0.0049, 0.0087, 0.0081, ..., 0.0145, 0.0188, 0.0441],\n", + " [-0.0103, 0.0147, 0.0180, ..., -0.0190, 0.0182, 0.0160],\n", + " [-0.0041, 0.0289, 0.0106, ..., 0.0144, -0.0070, 0.0104],\n", + " ...,\n", + " [ 0.0086, 0.0079, 0.0155, ..., 0.0037, -0.0242, 0.0091],\n", + " [-0.0320, 0.0084, -0.0508, ..., 0.0003, -0.0120, 0.0129],\n", + " [ 0.0079, 0.0185, 0.0285, ..., -0.0324, 0.0444, -0.0147]])),\n", + " ('model.layers.4.mlp.up_proj.weight',\n", + " tensor([[ 3.4382e-03, 1.9171e-02, 4.1226e-03, ..., 1.3158e-02,\n", + " 3.6365e-02, -8.1017e-03],\n", + " [ 1.8713e-02, -2.7732e-03, 3.1982e-02, ..., -8.5724e-03,\n", + " -3.1505e-02, 2.1047e-03],\n", + " [ 1.2329e-02, 1.8352e-03, 9.2540e-03, ..., 2.9880e-02,\n", + " -2.7856e-04, -8.7440e-04],\n", + " ...,\n", + " [-2.2330e-02, -2.0716e-02, 9.0004e-05, ..., -1.6298e-02,\n", + " -1.9620e-02, 2.5112e-02],\n", + " [ 7.1659e-03, 1.2942e-02, 1.0291e-03, ..., -1.0113e-02,\n", + " -1.6838e-03, 2.0189e-02],\n", + " [ 7.2108e-03, 3.1229e-02, 2.2533e-03, ..., -2.0148e-02,\n", + " -1.3502e-02, -1.8923e-02]])),\n", + " ('model.layers.4.mlp.down_proj.weight',\n", + " tensor([[ 0.0140, -0.0129, 0.0005, ..., -0.0068, -0.0335, 0.0172],\n", + " [-0.0175, -0.0011, 0.0114, ..., -0.0087, -0.0048, -0.0231],\n", + " [-0.0053, -0.0079, -0.0172, ..., -0.0125, -0.0200, 0.0127],\n", + " ...,\n", + " [ 0.0321, -0.0039, 0.0142, ..., 0.0384, 0.0054, 0.0321],\n", + " [ 0.0041, -0.0150, 0.0141, ..., 0.0049, -0.0348, -0.0028],\n", + " [ 0.0176, 0.0132, 0.0090, ..., -0.0117, 0.0241, 0.0417]])),\n", + " ('model.layers.4.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.4.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.5.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.5.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.5.mixer.in_proj.weight',\n", + " tensor([[ 0.0270, 0.0124, 0.0098, ..., 0.0170, -0.0225, 0.0032],\n", + " [ 0.0245, -0.0008, 0.0226, ..., 0.0219, -0.0219, 0.0087],\n", + " [-0.0175, 0.0181, 0.0124, ..., 0.0038, -0.0094, 0.0079],\n", + " ...,\n", + " [-0.0080, -0.0011, 0.0316, ..., -0.0012, 0.0254, 0.0251],\n", + " [-0.0141, -0.0159, -0.0069, ..., 0.0147, -0.0161, -0.0093],\n", + " [ 0.0252, 0.0125, 0.0174, ..., -0.0065, 0.0110, 0.0272]])),\n", + " ('model.layers.5.mixer.conv1d.weight',\n", + " tensor([[[ 0.0684, -0.4353, 0.3899, 0.3199]],\n", + " \n", + " [[ 0.4136, 0.4306, -0.4871, 0.4781]],\n", + " \n", + " [[-0.2516, 0.2109, 0.3891, 0.1501]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0781, -0.0675, -0.2995, -0.1805]],\n", + " \n", + " [[-0.3360, -0.4148, 0.1846, -0.1013]],\n", + " \n", + " [[ 0.1725, 0.1929, -0.0337, 0.1375]]])),\n", + " ('model.layers.5.mixer.conv1d.bias',\n", + " tensor([-0.4975, -0.0629, -0.2420, ..., -0.2253, 0.2512, 0.2788])),\n", + " ('model.layers.5.mixer.out_proj.weight',\n", + " tensor([[ 1.4306e-02, 1.3230e-02, -2.4141e-02, ..., 1.1763e-02,\n", + " 7.0706e-03, -4.7970e-03],\n", + " [ 2.7478e-02, 1.5179e-03, 1.9229e-02, ..., 1.0928e-02,\n", + " 2.2802e-02, -2.9729e-03],\n", + " [ 1.0169e-02, -1.0741e-02, 2.0628e-02, ..., -1.8109e-02,\n", + " -4.2582e-03, 2.4007e-02],\n", + " ...,\n", + " [-3.2843e-03, 3.7835e-03, -6.7958e-03, ..., -2.6205e-02,\n", + " -2.0391e-02, 5.3912e-03],\n", + " [ 1.2515e-02, -6.4975e-03, 9.9616e-05, ..., 1.0444e-02,\n", + " -2.0596e-02, -8.2915e-03],\n", + " [ 1.7899e-02, 2.0418e-02, -1.9891e-02, ..., -6.6709e-03,\n", + " -3.8566e-02, 2.7005e-02]])),\n", + " ('model.layers.5.mlp.gate_proj.weight',\n", + " tensor([[-2.3807e-03, 2.2714e-03, 2.2736e-05, ..., -2.3039e-03,\n", + " 3.6159e-02, -1.7253e-02],\n", + " [ 3.6929e-02, -6.2031e-03, 1.3606e-02, ..., 2.3592e-02,\n", + " 4.4487e-03, -9.6723e-03],\n", + " [ 4.7507e-02, 2.6413e-02, 1.6759e-02, ..., 1.1910e-02,\n", + " 1.2872e-02, -1.0443e-02],\n", + " ...,\n", + " [-2.0354e-02, -3.9074e-03, 9.7952e-03, ..., 1.0730e-02,\n", + " 2.8752e-02, -8.0048e-03],\n", + " [ 2.5331e-02, -9.9732e-03, 1.0772e-02, ..., 2.0420e-02,\n", + " -3.2179e-02, -1.6437e-02],\n", + " [-3.4425e-02, -1.4578e-02, 2.9686e-03, ..., 4.5907e-02,\n", + " 7.7639e-03, -2.2494e-03]])),\n", + " ('model.layers.5.mlp.up_proj.weight',\n", + " tensor([[ 1.5868e-02, -1.9222e-02, -1.2880e-03, ..., 8.3353e-03,\n", + " -1.8538e-02, 6.7395e-03],\n", + " [-1.8051e-02, -5.0142e-02, -2.2177e-03, ..., -9.3852e-03,\n", + " -3.0374e-02, 2.5795e-02],\n", + " [-1.1737e-02, 2.6278e-02, -2.3205e-02, ..., -1.8399e-03,\n", + " 1.4115e-02, -2.6438e-02],\n", + " ...,\n", + " [ 2.7706e-02, -2.5067e-03, -8.7058e-03, ..., 2.1662e-03,\n", + " -4.9858e-02, -1.1575e-02],\n", + " [-9.5670e-04, 2.1698e-02, -5.4794e-03, ..., -1.0661e-02,\n", + " 1.8568e-02, 5.2615e-03],\n", + " [ 1.0739e-03, 2.2945e-02, 3.0835e-02, ..., 4.1212e-03,\n", + " 1.2643e-02, -1.1568e-05]])),\n", + " ('model.layers.5.mlp.down_proj.weight',\n", + " tensor([[ 0.0052, -0.0343, 0.0072, ..., 0.0004, 0.0320, 0.0362],\n", + " [ 0.0171, -0.0238, -0.0316, ..., 0.0231, 0.0377, 0.0141],\n", + " [-0.0205, 0.0152, 0.0002, ..., -0.0061, -0.0353, -0.0138],\n", + " ...,\n", + " [-0.0039, -0.0039, 0.0326, ..., -0.0208, 0.0160, 0.0185],\n", + " [ 0.0176, -0.0300, -0.0024, ..., -0.0292, -0.0254, -0.0366],\n", + " [ 0.0361, 0.0243, -0.0253, ..., -0.0036, -0.0099, -0.0133]])),\n", + " ('model.layers.5.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.5.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.6.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.6.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.6.mixer.in_proj.weight',\n", + " tensor([[-0.0505, -0.0650, 0.0059, ..., 0.0060, 0.0347, 0.0149],\n", + " [-0.0216, 0.0057, -0.0281, ..., -0.0162, 0.0081, 0.0016],\n", + " [-0.0339, -0.0314, 0.0253, ..., 0.0030, 0.0139, -0.0039],\n", + " ...,\n", + " [ 0.0355, -0.0238, -0.0015, ..., 0.0063, 0.0284, -0.0089],\n", + " [ 0.0093, -0.0381, -0.0261, ..., -0.0170, -0.0170, -0.0288],\n", + " [-0.0228, -0.0110, 0.0107, ..., 0.0300, 0.0010, 0.0141]])),\n", + " ('model.layers.6.mixer.conv1d.weight',\n", + " tensor([[[ 0.4364, 0.2888, 0.2343, 0.3226]],\n", + " \n", + " [[ 0.2804, 0.3558, 0.4061, -0.0480]],\n", + " \n", + " [[ 0.4964, 0.0709, 0.0748, 0.0971]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4291, 0.2445, -0.3121, 0.4013]],\n", + " \n", + " [[-0.1590, -0.1516, 0.0804, 0.2009]],\n", + " \n", + " [[ 0.1686, 0.0492, -0.2932, 0.1381]]])),\n", + " ('model.layers.6.mixer.conv1d.bias',\n", + " tensor([ 0.4241, -0.0500, 0.3393, ..., 0.1598, -0.4924, -0.3241])),\n", + " ('model.layers.6.mixer.out_proj.weight',\n", + " tensor([[ 0.0026, 0.0272, 0.0005, ..., 0.0434, -0.0293, -0.0105],\n", + " [ 0.0323, -0.0515, 0.0107, ..., -0.0406, 0.0252, -0.0038],\n", + " [-0.0156, -0.0078, 0.0173, ..., 0.0312, -0.0014, -0.0014],\n", + " ...,\n", + " [ 0.0014, -0.0522, -0.0154, ..., 0.0090, -0.0050, -0.0049],\n", + " [ 0.0350, 0.0099, -0.0014, ..., -0.0008, -0.0185, -0.0033],\n", + " [ 0.0134, 0.0002, 0.0325, ..., -0.0129, 0.0165, -0.0265]])),\n", + " ('model.layers.6.mlp.gate_proj.weight',\n", + " tensor([[-0.0011, 0.0202, 0.0236, ..., -0.0137, -0.0063, 0.0085],\n", + " [ 0.0163, 0.0261, 0.0120, ..., -0.0003, -0.0254, 0.0001],\n", + " [ 0.0318, -0.0121, 0.0103, ..., -0.0053, 0.0194, 0.0530],\n", + " ...,\n", + " [ 0.0039, 0.0228, -0.0147, ..., 0.0027, 0.0092, -0.0033],\n", + " [-0.0040, 0.0144, 0.0038, ..., -0.0106, -0.0022, 0.0094],\n", + " [ 0.0220, 0.0296, 0.0550, ..., 0.0079, -0.0135, -0.0092]])),\n", + " ('model.layers.6.mlp.up_proj.weight',\n", + " tensor([[ 0.0061, -0.0291, -0.0133, ..., 0.0054, -0.0049, -0.0028],\n", + " [-0.0032, -0.0201, 0.0218, ..., -0.0155, -0.0264, 0.0496],\n", + " [-0.0046, 0.0384, -0.0093, ..., 0.0356, -0.0245, 0.0175],\n", + " ...,\n", + " [-0.0111, -0.0092, -0.0143, ..., 0.0010, -0.0453, 0.0024],\n", + " [ 0.0078, -0.0025, 0.0227, ..., -0.0130, 0.0118, 0.0095],\n", + " [ 0.0234, -0.0114, -0.0102, ..., -0.0179, -0.0066, -0.0115]])),\n", + " ('model.layers.6.mlp.down_proj.weight',\n", + " tensor([[ 3.6976e-02, 1.7124e-02, -2.1290e-02, ..., -2.5206e-02,\n", + " 4.8023e-03, 9.8474e-03],\n", + " [-7.2866e-03, -5.4149e-03, -2.2242e-03, ..., -8.1606e-03,\n", + " -9.5275e-04, -1.8121e-02],\n", + " [-8.3493e-03, 1.2509e-02, 1.0773e-02, ..., 2.7061e-02,\n", + " 2.8131e-03, 5.8219e-03],\n", + " ...,\n", + " [ 8.7099e-03, 3.9196e-02, -3.5129e-03, ..., -2.3595e-02,\n", + " -8.3965e-03, 2.0074e-02],\n", + " [-2.7467e-02, -2.8721e-03, -2.2291e-02, ..., 9.7135e-03,\n", + " 3.4947e-02, -2.2158e-02],\n", + " [ 6.1744e-03, -4.7684e-03, 4.6690e-04, ..., -3.2948e-03,\n", + " 4.0735e-05, 3.3651e-02]])),\n", + " ('model.layers.6.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.6.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.7.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.7.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.7.mixer.in_proj.weight',\n", + " tensor([[-0.0045, -0.0288, 0.0362, ..., -0.0092, -0.0026, 0.0051],\n", + " [ 0.0160, 0.0139, 0.0057, ..., 0.0121, 0.0071, 0.0134],\n", + " [ 0.0062, 0.0181, 0.0161, ..., -0.0284, -0.0014, -0.0171],\n", + " ...,\n", + " [-0.0053, 0.0067, 0.0095, ..., -0.0175, 0.0235, 0.0125],\n", + " [-0.0048, 0.0041, 0.0038, ..., 0.0099, 0.0194, 0.0124],\n", + " [ 0.0131, 0.0073, -0.0284, ..., 0.0138, -0.0218, 0.0019]])),\n", + " ('model.layers.7.mixer.conv1d.weight',\n", + " tensor([[[ 0.2528, -0.0556, -0.3225, 0.1327]],\n", + " \n", + " [[-0.0437, 0.4941, -0.4075, 0.1062]],\n", + " \n", + " [[-0.3428, 0.2675, 0.1871, 0.0260]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0409, -0.4458, 0.4488, 0.2841]],\n", + " \n", + " [[-0.2370, -0.3965, 0.0656, -0.1339]],\n", + " \n", + " [[ 0.4677, 0.0073, 0.3741, 0.1525]]])),\n", + " ('model.layers.7.mixer.conv1d.bias',\n", + " tensor([-0.1844, -0.1347, 0.0043, ..., -0.3839, -0.2167, -0.4637])),\n", + " ('model.layers.7.mixer.out_proj.weight',\n", + " tensor([[-2.8471e-02, 3.9783e-03, 6.0125e-03, ..., -1.6079e-02,\n", + " 1.4225e-02, 2.8166e-02],\n", + " [ 5.4680e-03, -5.1414e-03, 5.3077e-05, ..., 1.8734e-02,\n", + " 3.7454e-03, 1.7579e-02],\n", + " [-1.2955e-02, 1.4954e-02, 6.4922e-03, ..., -2.6830e-02,\n", + " 1.4766e-02, -1.8002e-02],\n", + " ...,\n", + " [ 1.7150e-02, 4.6781e-02, -1.1136e-02, ..., 4.7242e-03,\n", + " -1.3072e-02, -1.0412e-02],\n", + " [ 5.5498e-03, -3.0803e-02, -2.4880e-02, ..., -4.2644e-03,\n", + " -1.1047e-02, 1.5815e-02],\n", + " [ 1.7242e-02, 2.7994e-02, -4.8186e-04, ..., -2.2003e-02,\n", + " -2.1834e-02, -2.1826e-02]])),\n", + " ('model.layers.7.mlp.gate_proj.weight',\n", + " tensor([[-0.0302, -0.0160, -0.0341, ..., -0.0121, 0.0007, -0.0338],\n", + " [-0.0186, 0.0257, -0.0154, ..., 0.0153, -0.0029, 0.0163],\n", + " [ 0.0170, 0.0223, -0.0185, ..., -0.0020, 0.0061, 0.0174],\n", + " ...,\n", + " [-0.0044, 0.0044, 0.0077, ..., -0.0183, 0.0041, -0.0003],\n", + " [ 0.0168, 0.0149, -0.0221, ..., 0.0112, 0.0357, 0.0042],\n", + " [ 0.0310, -0.0217, 0.0070, ..., -0.0394, -0.0065, 0.0204]])),\n", + " ('model.layers.7.mlp.up_proj.weight',\n", + " tensor([[-0.0031, -0.0110, 0.0091, ..., 0.0152, -0.0013, 0.0096],\n", + " [ 0.0013, 0.0354, -0.0037, ..., 0.0130, 0.0204, 0.0262],\n", + " [-0.0075, -0.0044, 0.0207, ..., 0.0057, 0.0115, 0.0151],\n", + " ...,\n", + " [-0.0015, 0.0095, -0.0100, ..., -0.0150, 0.0105, -0.0350],\n", + " [-0.0300, -0.0092, -0.0176, ..., -0.0113, 0.0164, -0.0117],\n", + " [-0.0291, -0.0085, 0.0058, ..., 0.0386, -0.0174, -0.0092]])),\n", + " ('model.layers.7.mlp.down_proj.weight',\n", + " tensor([[-0.0276, 0.0017, -0.0217, ..., 0.0302, -0.0079, -0.0003],\n", + " [ 0.0379, 0.0052, 0.0052, ..., 0.0145, 0.0139, -0.0143],\n", + " [ 0.0176, -0.0028, 0.0172, ..., -0.0205, -0.0165, -0.0040],\n", + " ...,\n", + " [ 0.0095, -0.0139, 0.0077, ..., -0.0080, 0.0339, 0.0172],\n", + " [-0.0177, 0.0009, -0.0245, ..., 0.0040, 0.0258, 0.0202],\n", + " [-0.0064, -0.0270, 0.0041, ..., -0.0133, -0.0040, 0.0038]])),\n", + " ('model.layers.7.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.7.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.8.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.8.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.8.mixer.in_proj.weight',\n", + " tensor([[ 0.0050, 0.0270, -0.0196, ..., -0.0121, -0.0090, 0.0083],\n", + " [-0.0083, -0.0177, 0.0159, ..., 0.0298, -0.0202, -0.0265],\n", + " [ 0.0058, 0.0186, 0.0125, ..., -0.0067, -0.0255, 0.0298],\n", + " ...,\n", + " [-0.0164, 0.0012, 0.0023, ..., -0.0355, 0.0347, -0.0011],\n", + " [-0.0371, 0.0033, 0.0345, ..., -0.0097, 0.0019, 0.0185],\n", + " [-0.0322, -0.0160, 0.0072, ..., -0.0195, -0.0229, 0.0118]])),\n", + " ('model.layers.8.mixer.conv1d.weight',\n", + " tensor([[[-0.0520, 0.3004, -0.1990, 0.2512]],\n", + " \n", + " [[-0.4120, -0.0055, 0.1484, -0.3316]],\n", + " \n", + " [[ 0.3939, -0.0567, 0.1432, 0.1880]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.2849, 0.2494, -0.2141, -0.3375]],\n", + " \n", + " [[-0.2823, -0.2402, 0.2228, 0.2331]],\n", + " \n", + " [[ 0.1914, 0.4269, 0.1228, -0.3408]]])),\n", + " ('model.layers.8.mixer.conv1d.bias',\n", + " tensor([0.1304, 0.2065, 0.3084, ..., 0.3863, 0.4883, 0.4724])),\n", + " ('model.layers.8.mixer.out_proj.weight',\n", + " tensor([[ 0.0008, -0.0019, 0.0084, ..., -0.0003, 0.0045, 0.0024],\n", + " [ 0.0137, -0.0003, -0.0031, ..., 0.0013, 0.0131, 0.0090],\n", + " [ 0.0095, 0.0488, -0.0355, ..., 0.0344, -0.0229, -0.0150],\n", + " ...,\n", + " [ 0.0029, 0.0164, -0.0380, ..., -0.0005, -0.0031, 0.0127],\n", + " [-0.0039, 0.0283, 0.0295, ..., 0.0271, -0.0105, -0.0158],\n", + " [-0.0057, -0.0178, 0.0129, ..., 0.0323, -0.0091, 0.0178]])),\n", + " ('model.layers.8.mlp.gate_proj.weight',\n", + " tensor([[-0.0047, 0.0037, -0.0129, ..., 0.0255, -0.0118, 0.0084],\n", + " [ 0.0418, -0.0020, 0.0205, ..., 0.0161, 0.0306, 0.0250],\n", + " [ 0.0011, 0.0144, 0.0204, ..., -0.0007, 0.0298, -0.0067],\n", + " ...,\n", + " [-0.0536, -0.0083, -0.0049, ..., -0.0028, 0.0301, -0.0205],\n", + " [ 0.0031, 0.0139, 0.0070, ..., 0.0120, 0.0004, -0.0226],\n", + " [ 0.0114, -0.0173, 0.0212, ..., -0.0413, -0.0069, 0.0007]])),\n", + " ('model.layers.8.mlp.up_proj.weight',\n", + " tensor([[-0.0005, 0.0028, -0.0137, ..., 0.0078, 0.0348, 0.0006],\n", + " [-0.0020, 0.0300, -0.0056, ..., -0.0258, -0.0130, -0.0212],\n", + " [-0.0135, -0.0111, 0.0151, ..., 0.0043, -0.0426, -0.0109],\n", + " ...,\n", + " [ 0.0273, 0.0057, -0.0108, ..., -0.0205, 0.0005, -0.0239],\n", + " [ 0.0226, 0.0325, -0.0187, ..., 0.0069, -0.0132, -0.0002],\n", + " [ 0.0280, -0.0007, -0.0047, ..., 0.0159, -0.0054, -0.0172]])),\n", + " ('model.layers.8.mlp.down_proj.weight',\n", + " tensor([[-0.0091, 0.0072, 0.0030, ..., 0.0025, -0.0159, -0.0277],\n", + " [ 0.0159, -0.0260, -0.0076, ..., -0.0059, -0.0129, 0.0358],\n", + " [ 0.0026, -0.0357, -0.0138, ..., -0.0326, -0.0291, 0.0010],\n", + " ...,\n", + " [-0.0237, 0.0272, -0.0130, ..., -0.0280, 0.0097, -0.0563],\n", + " [ 0.0092, 0.0056, 0.0079, ..., -0.0224, 0.0039, -0.0054],\n", + " [-0.0109, -0.0241, -0.0223, ..., -0.0187, 0.0190, 0.0082]])),\n", + " ('model.layers.8.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.8.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.9.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.9.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.9.mixer.in_proj.weight',\n", + " tensor([[ 4.9824e-02, 5.7576e-03, -5.1022e-03, ..., -2.5615e-02,\n", + " 7.1750e-04, 1.5247e-02],\n", + " [-2.8065e-02, -1.2649e-02, -2.3566e-02, ..., 1.7742e-02,\n", + " -1.1202e-02, -2.1476e-02],\n", + " [ 2.0911e-02, 1.6496e-02, -1.9818e-02, ..., 4.0223e-02,\n", + " 1.8544e-02, -2.3633e-02],\n", + " ...,\n", + " [-4.3387e-02, -1.6504e-02, 2.2008e-02, ..., -2.5138e-03,\n", + " -5.6073e-03, -4.8212e-03],\n", + " [-1.9964e-05, -1.5835e-02, 1.2977e-02, ..., 4.1913e-03,\n", + " 4.5898e-02, -3.5822e-02],\n", + " [ 3.1376e-02, -5.4614e-03, -2.5093e-02, ..., -3.7903e-03,\n", + " 1.3560e-02, 3.3366e-02]])),\n", + " ('model.layers.9.mixer.conv1d.weight',\n", + " tensor([[[ 0.1986, -0.1666, -0.4140, -0.4607]],\n", + " \n", + " [[-0.3454, -0.3973, 0.2169, -0.2138]],\n", + " \n", + " [[ 0.2006, -0.3736, 0.3944, -0.0589]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4604, 0.1224, -0.2571, -0.0286]],\n", + " \n", + " [[-0.2723, -0.1617, 0.3483, 0.2299]],\n", + " \n", + " [[ 0.4866, 0.2559, 0.3969, 0.0554]]])),\n", + " ('model.layers.9.mixer.conv1d.bias',\n", + " tensor([ 0.3388, 0.4633, -0.3762, ..., -0.3491, -0.2971, 0.0494])),\n", + " ('model.layers.9.mixer.out_proj.weight',\n", + " tensor([[ 0.0023, -0.0181, 0.0358, ..., 0.0243, 0.0070, -0.0183],\n", + " [ 0.0006, 0.0065, 0.0057, ..., -0.0351, -0.0107, 0.0132],\n", + " [ 0.0153, -0.0038, 0.0059, ..., -0.0285, -0.0247, -0.0104],\n", + " ...,\n", + " [ 0.0244, -0.0120, 0.0064, ..., -0.0133, 0.0263, 0.0016],\n", + " [ 0.0056, -0.0111, 0.0029, ..., -0.0017, -0.0172, -0.0071],\n", + " [-0.0056, -0.0192, -0.0238, ..., 0.0245, -0.0102, -0.0331]])),\n", + " ('model.layers.9.mlp.gate_proj.weight',\n", + " tensor([[-0.0132, 0.0014, -0.0413, ..., -0.0254, -0.0245, 0.0031],\n", + " [-0.0195, -0.0107, -0.0192, ..., 0.0012, -0.0026, 0.0148],\n", + " [-0.0074, -0.0070, -0.0078, ..., 0.0013, -0.0011, -0.0111],\n", + " ...,\n", + " [-0.0137, 0.0302, 0.0084, ..., -0.0063, -0.0065, 0.0240],\n", + " [ 0.0072, 0.0134, 0.0161, ..., 0.0122, 0.0182, 0.0137],\n", + " [ 0.0079, 0.0008, 0.0160, ..., 0.0281, 0.0226, 0.0058]])),\n", + " ('model.layers.9.mlp.up_proj.weight',\n", + " tensor([[ 0.0078, 0.0153, -0.0155, ..., 0.0153, -0.0164, -0.0140],\n", + " [-0.0072, -0.0050, 0.0030, ..., 0.0146, -0.0148, -0.0080],\n", + " [ 0.0165, -0.0078, 0.0005, ..., -0.0545, -0.0096, 0.0296],\n", + " ...,\n", + " [-0.0253, 0.0183, -0.0081, ..., -0.0061, 0.0270, -0.0003],\n", + " [-0.0015, -0.0320, 0.0361, ..., -0.0087, 0.0341, -0.0157],\n", + " [ 0.0041, 0.0102, -0.0195, ..., -0.0441, -0.0106, 0.0275]])),\n", + " ('model.layers.9.mlp.down_proj.weight',\n", + " tensor([[-6.3367e-02, -1.8214e-02, 5.7221e-03, ..., 2.1307e-02,\n", + " -3.0707e-02, -1.3281e-02],\n", + " [-7.7457e-05, -9.1894e-05, 6.8686e-03, ..., -4.7175e-03,\n", + " -1.1585e-03, -2.7604e-02],\n", + " [ 2.9301e-02, -5.9431e-03, -2.5356e-03, ..., -2.7858e-02,\n", + " 1.1647e-02, 1.1245e-02],\n", + " ...,\n", + " [-1.0442e-02, -9.6151e-03, -3.6635e-02, ..., -1.1052e-02,\n", + " -4.5122e-03, 4.0012e-03],\n", + " [ 3.2950e-02, -1.3836e-03, -7.8318e-03, ..., -1.2788e-03,\n", + " 2.3422e-02, -3.2098e-02],\n", + " [-9.2294e-03, 1.3838e-02, -2.0327e-02, ..., -3.8760e-02,\n", + " 2.2118e-02, 1.0696e-02]])),\n", + " ('model.layers.9.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.9.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.10.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.10.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.10.mixer.in_proj.weight',\n", + " tensor([[ 0.0096, -0.0159, 0.0141, ..., 0.0111, 0.0218, 0.0220],\n", + " [-0.0381, -0.0015, 0.0126, ..., -0.0066, -0.0034, -0.0119],\n", + " [ 0.0223, 0.0032, -0.0195, ..., -0.0107, -0.0018, 0.0059],\n", + " ...,\n", + " [-0.0256, -0.0170, -0.0362, ..., -0.0007, -0.0039, 0.0075],\n", + " [ 0.0136, -0.0045, 0.0128, ..., -0.0017, 0.0083, -0.0004],\n", + " [-0.0246, -0.0021, 0.0073, ..., 0.0020, 0.0071, 0.0090]])),\n", + " ('model.layers.10.mixer.conv1d.weight',\n", + " tensor([[[ 0.0463, -0.4497, -0.0679, -0.2209]],\n", + " \n", + " [[-0.3805, 0.4459, 0.1999, -0.4996]],\n", + " \n", + " [[ 0.1529, 0.1789, -0.1535, 0.1824]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.1087, -0.4478, -0.0420, 0.3437]],\n", + " \n", + " [[-0.2809, -0.4617, 0.3209, 0.4873]],\n", + " \n", + " [[ 0.1139, -0.0060, -0.0219, 0.0853]]])),\n", + " ('model.layers.10.mixer.conv1d.bias',\n", + " tensor([ 0.1364, -0.0475, 0.0849, ..., 0.1928, 0.2075, 0.1058])),\n", + " ('model.layers.10.mixer.out_proj.weight',\n", + " tensor([[-0.0164, -0.0188, 0.0174, ..., -0.0106, -0.0107, -0.0036],\n", + " [ 0.0048, -0.0016, -0.0444, ..., -0.0182, -0.0264, -0.0038],\n", + " [ 0.0089, -0.0225, -0.0002, ..., -0.0141, -0.0008, -0.0037],\n", + " ...,\n", + " [-0.0005, 0.0159, 0.0033, ..., 0.0187, -0.0064, 0.0233],\n", + " [-0.0050, 0.0296, 0.0147, ..., -0.0018, 0.0137, -0.0346],\n", + " [-0.0064, -0.0132, -0.0434, ..., -0.0173, -0.0113, -0.0175]])),\n", + " ('model.layers.10.mlp.gate_proj.weight',\n", + " tensor([[-0.0174, -0.0053, -0.0325, ..., -0.0072, -0.0280, 0.0033],\n", + " [ 0.0006, -0.0160, 0.0346, ..., 0.0019, 0.0059, 0.0198],\n", + " [ 0.0231, -0.0187, 0.0115, ..., 0.0085, 0.0080, 0.0061],\n", + " ...,\n", + " [ 0.0153, 0.0241, -0.0184, ..., 0.0089, -0.0242, 0.0010],\n", + " [-0.0019, -0.0322, 0.0011, ..., -0.0097, -0.0305, 0.0065],\n", + " [-0.0107, 0.0240, 0.0168, ..., 0.0226, -0.0238, 0.0117]])),\n", + " ('model.layers.10.mlp.up_proj.weight',\n", + " tensor([[-0.0072, 0.0352, 0.0282, ..., -0.0025, -0.0114, 0.0129],\n", + " [-0.0102, 0.0196, 0.0760, ..., 0.0461, -0.0058, -0.0112],\n", + " [-0.0271, 0.0323, -0.0069, ..., 0.0133, -0.0371, -0.0619],\n", + " ...,\n", + " [ 0.0100, 0.0011, 0.0262, ..., -0.0232, 0.0217, 0.0002],\n", + " [ 0.0151, -0.0266, -0.0074, ..., 0.0096, 0.0036, 0.0033],\n", + " [ 0.0004, 0.0103, 0.0363, ..., -0.0095, -0.0309, -0.0059]])),\n", + " ('model.layers.10.mlp.down_proj.weight',\n", + " tensor([[ 0.0124, -0.0225, -0.0294, ..., 0.0280, 0.0056, 0.0231],\n", + " [ 0.0124, -0.0030, 0.0014, ..., 0.0323, 0.0094, -0.0034],\n", + " [-0.0078, 0.0041, -0.0056, ..., 0.0241, -0.0278, -0.0152],\n", + " ...,\n", + " [-0.0044, 0.0025, -0.0161, ..., -0.0075, -0.0126, 0.0014],\n", + " [-0.0109, -0.0050, 0.0327, ..., -0.0300, -0.0048, 0.0284],\n", + " [ 0.0050, -0.0183, 0.0086, ..., -0.0072, 0.0139, -0.0010]])),\n", + " ('model.layers.10.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.10.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.11.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.11.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.11.mixer.in_proj.weight',\n", + " tensor([[-0.0133, 0.0225, 0.0486, ..., -0.0214, -0.0120, -0.0150],\n", + " [ 0.0183, 0.0020, 0.0079, ..., -0.0163, 0.0016, -0.0214],\n", + " [-0.0276, -0.0112, 0.0121, ..., -0.0057, -0.0143, -0.0462],\n", + " ...,\n", + " [-0.0142, -0.0080, -0.0194, ..., 0.0087, -0.0212, -0.0140],\n", + " [ 0.0060, -0.0005, -0.0171, ..., -0.0017, 0.0223, 0.0169],\n", + " [-0.0290, -0.0016, 0.0117, ..., 0.0037, 0.0047, 0.0152]])),\n", + " ('model.layers.11.mixer.conv1d.weight',\n", + " tensor([[[-0.2822, -0.4216, 0.4786, 0.0802]],\n", + " \n", + " [[-0.3671, 0.1761, -0.2686, 0.1631]],\n", + " \n", + " [[-0.3902, -0.2811, -0.0748, 0.4662]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.1623, 0.2871, -0.4585, 0.4755]],\n", + " \n", + " [[-0.0260, 0.4541, -0.2983, 0.2297]],\n", + " \n", + " [[-0.2991, -0.3590, -0.3256, -0.1434]]])),\n", + " ('model.layers.11.mixer.conv1d.bias',\n", + " tensor([ 0.1218, -0.0542, 0.3485, ..., 0.0528, 0.2711, -0.2811])),\n", + " ('model.layers.11.mixer.out_proj.weight',\n", + " tensor([[ 0.0032, 0.0028, -0.0122, ..., -0.0299, -0.0105, 0.0021],\n", + " [-0.0466, -0.0170, -0.0017, ..., 0.0156, -0.0287, 0.0066],\n", + " [ 0.0016, 0.0054, -0.0071, ..., -0.0240, 0.0215, -0.0046],\n", + " ...,\n", + " [-0.0210, 0.0034, -0.0267, ..., 0.0461, -0.0076, -0.0016],\n", + " [-0.0012, -0.0101, 0.0196, ..., 0.0121, -0.0043, -0.0143],\n", + " [-0.0067, 0.0086, 0.0134, ..., 0.0080, 0.0255, 0.0225]])),\n", + " ('model.layers.11.mlp.gate_proj.weight',\n", + " tensor([[ 0.0179, -0.0429, -0.0134, ..., 0.0110, 0.0368, -0.0259],\n", + " [ 0.0013, -0.0231, 0.0072, ..., -0.0056, -0.0012, -0.0037],\n", + " [-0.0172, -0.0162, 0.0088, ..., -0.0175, 0.0079, -0.0065],\n", + " ...,\n", + " [ 0.0287, -0.0289, 0.0045, ..., 0.0039, 0.0269, 0.0199],\n", + " [ 0.0043, -0.0202, -0.0261, ..., 0.0104, -0.0161, -0.0057],\n", + " [-0.0154, 0.0085, 0.0061, ..., 0.0208, 0.0001, 0.0166]])),\n", + " ('model.layers.11.mlp.up_proj.weight',\n", + " tensor([[-0.0107, 0.0328, 0.0065, ..., -0.0190, -0.0082, -0.0047],\n", + " [-0.0001, 0.0102, 0.0310, ..., -0.0396, -0.0278, -0.0095],\n", + " [-0.0288, 0.0052, 0.0137, ..., -0.0220, 0.0007, -0.0170],\n", + " ...,\n", + " [ 0.0213, -0.0074, -0.0033, ..., 0.0183, 0.0336, -0.0180],\n", + " [-0.0098, -0.0162, 0.0486, ..., 0.0191, 0.0064, 0.0269],\n", + " [-0.0251, 0.0081, 0.0053, ..., 0.0110, 0.0023, 0.0041]])),\n", + " ('model.layers.11.mlp.down_proj.weight',\n", + " tensor([[ 0.0166, -0.0410, 0.0066, ..., -0.0273, 0.0220, 0.0184],\n", + " [ 0.0092, 0.0087, -0.0136, ..., 0.0013, -0.0205, 0.0247],\n", + " [-0.0252, -0.0040, -0.0112, ..., -0.0331, 0.0201, -0.0038],\n", + " ...,\n", + " [ 0.0072, 0.0190, 0.0089, ..., 0.0098, -0.0235, -0.0141],\n", + " [-0.0045, -0.0381, -0.0134, ..., 0.0171, -0.0077, -0.0180],\n", + " [ 0.0109, 0.0060, 0.0048, ..., -0.0108, -0.0122, 0.0110]])),\n", + " ('model.layers.11.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.11.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.12.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.12.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.12.mixer.in_proj.weight',\n", + " tensor([[ 0.0043, 0.0138, 0.0138, ..., -0.0042, 0.0121, -0.0190],\n", + " [ 0.0002, -0.0199, 0.0315, ..., 0.0170, 0.0051, -0.0062],\n", + " [-0.0053, 0.0043, 0.0283, ..., -0.0087, 0.0069, -0.0160],\n", + " ...,\n", + " [-0.0313, 0.0200, 0.0036, ..., 0.0147, 0.0153, 0.0098],\n", + " [-0.0157, 0.0120, -0.0112, ..., 0.0166, -0.0005, 0.0066],\n", + " [-0.0271, 0.0037, 0.0163, ..., 0.0304, 0.0023, 0.0083]])),\n", + " ('model.layers.12.mixer.conv1d.weight',\n", + " tensor([[[-0.4295, -0.2474, -0.2324, -0.2138]],\n", + " \n", + " [[ 0.3607, -0.4824, 0.1667, 0.1348]],\n", + " \n", + " [[ 0.3596, 0.1167, 0.1089, -0.4010]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3527, -0.3346, -0.3755, 0.1450]],\n", + " \n", + " [[-0.1921, -0.0632, -0.4885, -0.3986]],\n", + " \n", + " [[ 0.1950, 0.3037, -0.1630, 0.0353]]])),\n", + " ('model.layers.12.mixer.conv1d.bias',\n", + " tensor([0.3103, 0.0451, 0.4533, ..., 0.0235, 0.1819, 0.3933])),\n", + " ('model.layers.12.mixer.out_proj.weight',\n", + " tensor([[ 0.0167, -0.0197, -0.0054, ..., 0.0096, 0.0271, -0.0118],\n", + " [ 0.0167, -0.0455, 0.0001, ..., 0.0003, 0.0265, 0.0111],\n", + " [ 0.0231, -0.0113, 0.0195, ..., -0.0171, -0.0044, -0.0244],\n", + " ...,\n", + " [ 0.0042, 0.0048, 0.0357, ..., 0.0126, -0.0288, 0.0149],\n", + " [ 0.0192, 0.0078, 0.0126, ..., 0.0029, 0.0255, -0.0203],\n", + " [-0.0054, -0.0543, 0.0039, ..., -0.0240, 0.0282, 0.0082]])),\n", + " ('model.layers.12.mlp.gate_proj.weight',\n", + " tensor([[-0.0417, -0.0193, -0.0022, ..., 0.0031, 0.0337, 0.0175],\n", + " [ 0.0215, -0.0109, -0.0657, ..., -0.0145, -0.0475, -0.0091],\n", + " [-0.0225, -0.0012, -0.0020, ..., -0.0291, 0.0097, 0.0163],\n", + " ...,\n", + " [-0.0018, 0.0048, -0.0265, ..., -0.0056, 0.0446, 0.0045],\n", + " [ 0.0270, 0.0086, -0.0110, ..., -0.0038, 0.0176, 0.0138],\n", + " [-0.0134, 0.0046, -0.0186, ..., -0.0098, 0.0191, 0.0095]])),\n", + " ('model.layers.12.mlp.up_proj.weight',\n", + " tensor([[ 0.0180, 0.0075, 0.0147, ..., 0.0142, 0.0291, -0.0303],\n", + " [-0.0079, -0.0277, -0.0151, ..., -0.0069, -0.0045, -0.0223],\n", + " [ 0.0180, -0.0087, 0.0074, ..., 0.0215, 0.0274, -0.0199],\n", + " ...,\n", + " [-0.0215, -0.0115, 0.0140, ..., -0.0283, -0.0171, -0.0229],\n", + " [ 0.0231, -0.0179, -0.0386, ..., 0.0364, 0.0311, 0.0048],\n", + " [-0.0111, 0.0079, 0.0328, ..., 0.0285, 0.0423, 0.0039]])),\n", + " ('model.layers.12.mlp.down_proj.weight',\n", + " tensor([[-0.0361, 0.0192, -0.0005, ..., -0.0151, 0.0116, -0.0068],\n", + " [ 0.0203, -0.0064, 0.0061, ..., 0.0325, -0.0004, -0.0299],\n", + " [-0.0028, 0.0131, 0.0141, ..., -0.0108, -0.0070, -0.0090],\n", + " ...,\n", + " [ 0.0165, -0.0198, -0.0242, ..., 0.0162, 0.0099, 0.0025],\n", + " [ 0.0148, 0.0056, -0.0139, ..., 0.0108, -0.0477, 0.0225],\n", + " [ 0.0156, 0.0249, -0.0287, ..., -0.0200, -0.0496, 0.0169]])),\n", + " ('model.layers.12.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.12.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.13.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.13.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.13.mixer.in_proj.weight',\n", + " tensor([[-0.0064, -0.0200, 0.0384, ..., -0.0036, 0.0158, -0.0007],\n", + " [-0.0074, 0.0105, 0.0043, ..., 0.0097, 0.0259, -0.0012],\n", + " [ 0.0297, -0.0146, -0.0012, ..., 0.0273, 0.0309, 0.0087],\n", + " ...,\n", + " [ 0.0204, -0.0063, 0.0136, ..., -0.0092, 0.0196, 0.0057],\n", + " [ 0.0195, 0.0059, 0.0228, ..., 0.0093, -0.0183, -0.0003],\n", + " [-0.0131, -0.0447, -0.0262, ..., -0.0125, 0.0237, -0.0404]])),\n", + " ('model.layers.13.mixer.conv1d.weight',\n", + " tensor([[[ 7.7458e-03, 4.9829e-01, 2.1690e-01, -2.3587e-01]],\n", + " \n", + " [[ 3.7281e-01, -4.0991e-03, 2.4588e-01, -1.1600e-01]],\n", + " \n", + " [[-4.8238e-01, -2.8961e-01, -4.4331e-02, 1.0011e-01]],\n", + " \n", + " ...,\n", + " \n", + " [[-3.6304e-01, -1.4106e-01, -3.5434e-01, 1.4923e-01]],\n", + " \n", + " [[-2.3703e-01, 3.9285e-04, -2.1456e-02, -2.5568e-01]],\n", + " \n", + " [[ 1.5303e-02, -8.3474e-03, -3.2668e-01, -4.8096e-01]]])),\n", + " ('model.layers.13.mixer.conv1d.bias',\n", + " tensor([-0.2462, 0.1532, -0.2298, ..., -0.3016, 0.1210, -0.3777])),\n", + " ('model.layers.13.mixer.out_proj.weight',\n", + " tensor([[-0.0019, 0.0103, 0.0098, ..., -0.0050, 0.0180, -0.0117],\n", + " [-0.0153, 0.0134, -0.0102, ..., 0.0327, -0.0387, 0.0025],\n", + " [ 0.0102, -0.0038, 0.0224, ..., -0.0118, 0.0234, 0.0014],\n", + " ...,\n", + " [-0.0201, 0.0233, 0.0189, ..., 0.0010, 0.0313, 0.0130],\n", + " [ 0.0193, 0.0035, -0.0253, ..., 0.0084, -0.0208, 0.0372],\n", + " [ 0.0367, -0.0029, -0.0205, ..., -0.0055, -0.0209, 0.0082]])),\n", + " ('model.layers.13.mlp.gate_proj.weight',\n", + " tensor([[ 0.0148, -0.0052, 0.0371, ..., -0.0118, 0.0397, -0.0234],\n", + " [ 0.0237, -0.0323, 0.0219, ..., 0.0098, -0.0304, 0.0165],\n", + " [ 0.0168, -0.0289, 0.0038, ..., 0.0022, 0.0174, 0.0043],\n", + " ...,\n", + " [-0.0135, 0.0258, -0.0172, ..., 0.0251, -0.0071, -0.0384],\n", + " [ 0.0005, -0.0123, 0.0116, ..., 0.0041, -0.0108, -0.0068],\n", + " [ 0.0116, 0.0069, 0.0063, ..., 0.0045, -0.0145, 0.0185]])),\n", + " ('model.layers.13.mlp.up_proj.weight',\n", + " tensor([[-0.0002, -0.0120, 0.0069, ..., 0.0005, -0.0108, -0.0284],\n", + " [ 0.0215, 0.0045, 0.0167, ..., 0.0177, -0.0030, 0.0051],\n", + " [ 0.0265, 0.0169, 0.0047, ..., 0.0069, -0.0299, 0.0196],\n", + " ...,\n", + " [ 0.0127, -0.0063, 0.0242, ..., -0.0061, -0.0263, 0.0041],\n", + " [ 0.0142, -0.0515, -0.0221, ..., -0.0369, -0.0399, -0.0210],\n", + " [ 0.0123, 0.0133, -0.0269, ..., 0.0092, -0.0177, 0.0226]])),\n", + " ('model.layers.13.mlp.down_proj.weight',\n", + " tensor([[ 0.0048, 0.0360, -0.0037, ..., 0.0169, 0.0304, -0.0162],\n", + " [ 0.0271, -0.0121, 0.0108, ..., -0.0424, 0.0293, -0.0137],\n", + " [ 0.0225, -0.0061, -0.0096, ..., 0.0075, -0.0168, 0.0142],\n", + " ...,\n", + " [ 0.0039, -0.0152, -0.0156, ..., 0.0181, 0.0105, 0.0070],\n", + " [ 0.0311, 0.0205, 0.0259, ..., -0.0025, 0.0060, -0.0125],\n", + " [ 0.0004, -0.0114, 0.0022, ..., -0.0159, -0.0290, 0.0036]])),\n", + " ('model.layers.13.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.13.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.14.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.14.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.14.mixer.in_proj.weight',\n", + " tensor([[-0.0123, 0.0054, 0.0059, ..., 0.0285, -0.0292, -0.0184],\n", + " [-0.0146, -0.0175, 0.0155, ..., -0.0206, -0.0190, -0.0172],\n", + " [ 0.0050, -0.0235, -0.0159, ..., -0.0013, -0.0102, 0.0082],\n", + " ...,\n", + " [-0.0243, -0.0013, 0.0312, ..., -0.0141, -0.0156, 0.0279],\n", + " [ 0.0018, 0.0181, -0.0188, ..., 0.0593, -0.0155, 0.0156],\n", + " [ 0.0036, 0.0182, -0.0308, ..., 0.0306, -0.0035, 0.0037]])),\n", + " ('model.layers.14.mixer.conv1d.weight',\n", + " tensor([[[-0.4608, 0.4926, -0.2625, 0.3060]],\n", + " \n", + " [[-0.0932, 0.0153, 0.2298, -0.1735]],\n", + " \n", + " [[-0.1927, 0.1979, -0.1773, 0.3277]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0538, -0.2180, -0.4857, -0.1428]],\n", + " \n", + " [[-0.1736, 0.2405, 0.3148, -0.4481]],\n", + " \n", + " [[-0.4971, -0.1558, 0.2762, -0.1849]]])),\n", + " ('model.layers.14.mixer.conv1d.bias',\n", + " tensor([-0.2181, -0.2375, 0.0896, ..., 0.0744, 0.0857, 0.4347])),\n", + " ('model.layers.14.mixer.out_proj.weight',\n", + " tensor([[-3.8364e-04, 2.4458e-02, 5.8783e-03, ..., -1.3479e-02,\n", + " -2.4306e-02, 5.7698e-03],\n", + " [ 4.5843e-02, -3.9217e-03, -6.9897e-03, ..., 5.5401e-03,\n", + " -1.4523e-02, 1.2266e-02],\n", + " [-7.1069e-03, 5.5550e-03, 1.1359e-02, ..., 3.5839e-02,\n", + " 1.0787e-02, 8.4053e-03],\n", + " ...,\n", + " [ 3.3029e-03, 5.4333e-03, -9.3382e-03, ..., -1.7376e-02,\n", + " 1.5601e-02, -6.3227e-03],\n", + " [-6.9199e-03, -1.6950e-02, 1.5155e-03, ..., 1.2324e-02,\n", + " 1.2259e-02, 5.5500e-02],\n", + " [-1.6177e-02, -6.5257e-05, -9.3656e-03, ..., 1.0653e-02,\n", + " 1.8864e-02, -1.2508e-02]])),\n", + " ('model.layers.14.mlp.gate_proj.weight',\n", + " tensor([[ 0.0279, 0.0025, 0.0214, ..., -0.0137, -0.0042, 0.0172],\n", + " [-0.0240, -0.0150, 0.0170, ..., 0.0090, 0.0002, 0.0172],\n", + " [-0.0181, 0.0052, -0.0418, ..., 0.0106, 0.0052, -0.0264],\n", + " ...,\n", + " [-0.0295, 0.0323, 0.0387, ..., -0.0116, -0.0140, -0.0053],\n", + " [ 0.0411, 0.0189, 0.0236, ..., 0.0094, -0.0176, -0.0066],\n", + " [ 0.0004, 0.0291, 0.0402, ..., 0.0127, -0.0009, 0.0010]])),\n", + " ('model.layers.14.mlp.up_proj.weight',\n", + " tensor([[ 0.0198, -0.0115, -0.0045, ..., 0.0273, 0.0012, -0.0082],\n", + " [-0.0217, 0.0075, 0.0006, ..., 0.0047, -0.0416, -0.0011],\n", + " [ 0.0012, -0.0214, -0.0211, ..., 0.0030, -0.0176, -0.0215],\n", + " ...,\n", + " [ 0.0062, -0.0305, 0.0310, ..., 0.0044, -0.0379, 0.0155],\n", + " [-0.0062, 0.0451, 0.0167, ..., 0.0062, -0.0033, 0.0012],\n", + " [ 0.0293, -0.0186, 0.0295, ..., 0.0092, 0.0100, 0.0038]])),\n", + " ('model.layers.14.mlp.down_proj.weight',\n", + " tensor([[ 0.0019, 0.0114, -0.0202, ..., 0.0227, -0.0227, -0.0005],\n", + " [-0.0437, -0.0045, -0.0385, ..., -0.0083, -0.0135, 0.0172],\n", + " [-0.0032, -0.0024, 0.0137, ..., 0.0071, 0.0034, 0.0104],\n", + " ...,\n", + " [ 0.0210, -0.0237, -0.0166, ..., -0.0105, 0.0490, 0.0155],\n", + " [-0.0109, 0.0112, 0.0082, ..., -0.0342, -0.0133, -0.0086],\n", + " [ 0.0282, -0.0210, -0.0127, ..., -0.0047, -0.0126, 0.0103]])),\n", + " ('model.layers.14.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.14.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.15.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.15.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.15.mixer.in_proj.weight',\n", + " tensor([[-0.0098, -0.0201, -0.0033, ..., -0.0289, 0.0275, 0.0186],\n", + " [ 0.0048, 0.0075, -0.0033, ..., 0.0011, 0.0042, 0.0040],\n", + " [-0.0079, -0.0025, 0.0018, ..., -0.0051, -0.0231, -0.0022],\n", + " ...,\n", + " [ 0.0186, -0.0104, -0.0062, ..., 0.0086, -0.0007, -0.0653],\n", + " [-0.0212, 0.0034, 0.0019, ..., 0.0167, 0.0050, 0.0120],\n", + " [ 0.0066, 0.0381, -0.0225, ..., -0.0043, 0.0229, -0.0004]])),\n", + " ('model.layers.15.mixer.conv1d.weight',\n", + " tensor([[[ 0.2306, 0.2721, 0.3406, 0.4513]],\n", + " \n", + " [[ 0.0991, 0.4973, 0.0010, -0.1445]],\n", + " \n", + " [[ 0.2975, 0.4813, 0.2817, -0.0468]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0104, -0.1473, 0.1685, -0.4390]],\n", + " \n", + " [[ 0.3669, 0.3461, 0.0845, 0.3576]],\n", + " \n", + " [[-0.1177, 0.0524, 0.4329, 0.0687]]])),\n", + " ('model.layers.15.mixer.conv1d.bias',\n", + " tensor([-0.0356, 0.4173, 0.3287, ..., -0.0141, 0.1365, 0.2086])),\n", + " ('model.layers.15.mixer.out_proj.weight',\n", + " tensor([[-0.0137, -0.0239, -0.0133, ..., -0.0177, -0.0125, -0.0015],\n", + " [ 0.0168, 0.0120, 0.0034, ..., 0.0098, 0.0098, 0.0110],\n", + " [-0.0315, 0.0447, 0.0189, ..., 0.0305, 0.0131, -0.0230],\n", + " ...,\n", + " [-0.0480, 0.0170, 0.0025, ..., 0.0317, -0.0378, -0.0236],\n", + " [-0.0319, -0.0290, 0.0023, ..., -0.0093, 0.0354, 0.0126],\n", + " [-0.0107, 0.0100, -0.0101, ..., 0.0046, 0.0205, -0.0203]])),\n", + " ('model.layers.15.mlp.gate_proj.weight',\n", + " tensor([[ 0.0160, 0.0432, 0.0073, ..., -0.0003, -0.0170, 0.0236],\n", + " [ 0.0055, 0.0066, -0.0311, ..., 0.0049, -0.0130, 0.0040],\n", + " [-0.0147, -0.0184, 0.0281, ..., 0.0016, 0.0077, -0.0072],\n", + " ...,\n", + " [-0.0049, -0.0434, -0.0118, ..., 0.0137, -0.0225, -0.0058],\n", + " [ 0.0221, -0.0077, 0.0029, ..., 0.0087, -0.0361, -0.0100],\n", + " [ 0.0263, 0.0228, 0.0050, ..., -0.0557, 0.0037, 0.0196]])),\n", + " ('model.layers.15.mlp.up_proj.weight',\n", + " tensor([[ 0.0093, -0.0189, 0.0173, ..., 0.0276, 0.0075, -0.0215],\n", + " [-0.0147, 0.0241, 0.0109, ..., 0.0120, 0.0032, 0.0327],\n", + " [ 0.0036, 0.0127, 0.0116, ..., 0.0100, -0.0003, 0.0233],\n", + " ...,\n", + " [-0.0063, 0.0160, 0.0138, ..., -0.0078, -0.0098, 0.0150],\n", + " [ 0.0138, -0.0236, 0.0109, ..., -0.0156, -0.0143, 0.0273],\n", + " [ 0.0345, 0.0201, -0.0119, ..., -0.0182, 0.0053, 0.0105]])),\n", + " ('model.layers.15.mlp.down_proj.weight',\n", + " tensor([[-0.0114, 0.0138, -0.0110, ..., 0.0084, -0.0144, 0.0100],\n", + " [ 0.0016, -0.0069, 0.0172, ..., -0.0394, 0.0368, 0.0468],\n", + " [-0.0184, -0.0094, -0.0273, ..., -0.0195, 0.0148, 0.0142],\n", + " ...,\n", + " [ 0.0311, 0.0093, -0.0130, ..., -0.0023, 0.0395, -0.0375],\n", + " [ 0.0056, 0.0027, 0.0061, ..., 0.0058, 0.0225, -0.0153],\n", + " [-0.0031, -0.0107, 0.0020, ..., -0.0173, -0.0050, 0.0423]])),\n", + " ('model.layers.15.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.15.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.16.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.16.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.16.mixer.in_proj.weight',\n", + " tensor([[-0.0063, 0.0006, 0.0130, ..., 0.0186, 0.0408, 0.0126],\n", + " [-0.0015, -0.0029, 0.0268, ..., -0.0042, -0.0209, -0.0046],\n", + " [-0.0034, -0.0286, 0.0185, ..., -0.0125, 0.0050, 0.0033],\n", + " ...,\n", + " [ 0.0045, 0.0133, 0.0220, ..., 0.0165, 0.0287, 0.0371],\n", + " [ 0.0100, -0.0232, 0.0103, ..., -0.0083, -0.0105, -0.0187],\n", + " [-0.0412, -0.0035, 0.0028, ..., 0.0286, 0.0349, -0.0037]])),\n", + " ('model.layers.16.mixer.conv1d.weight',\n", + " tensor([[[-0.1874, 0.2517, 0.0537, 0.1258]],\n", + " \n", + " [[ 0.1465, 0.2013, 0.3547, 0.2689]],\n", + " \n", + " [[ 0.4834, 0.4906, 0.0844, -0.0541]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3004, 0.3313, 0.1688, 0.4381]],\n", + " \n", + " [[-0.0606, 0.3455, -0.0910, 0.1148]],\n", + " \n", + " [[-0.1421, -0.1254, -0.2353, -0.1675]]])),\n", + " ('model.layers.16.mixer.conv1d.bias',\n", + " tensor([ 0.2835, 0.2361, 0.1225, ..., -0.2119, -0.1929, 0.3877])),\n", + " ('model.layers.16.mixer.out_proj.weight',\n", + " tensor([[-0.0121, 0.0194, 0.0060, ..., -0.0029, -0.0147, -0.0085],\n", + " [-0.0216, -0.0012, 0.0287, ..., 0.0102, -0.0133, -0.0153],\n", + " [ 0.0136, -0.0296, 0.0417, ..., -0.0118, -0.0283, 0.0359],\n", + " ...,\n", + " [-0.0263, -0.0003, 0.0022, ..., 0.0135, -0.0519, -0.0254],\n", + " [ 0.0121, -0.0144, -0.0026, ..., 0.0096, 0.0130, 0.0095],\n", + " [-0.0147, -0.0217, 0.0099, ..., 0.0267, -0.0072, -0.0213]])),\n", + " ('model.layers.16.mlp.gate_proj.weight',\n", + " tensor([[ 0.0103, -0.0396, -0.0127, ..., 0.0020, -0.0055, 0.0291],\n", + " [ 0.0194, 0.0357, -0.0020, ..., -0.0112, 0.0448, -0.0224],\n", + " [-0.0390, 0.0142, -0.0224, ..., -0.0030, 0.0102, 0.0078],\n", + " ...,\n", + " [ 0.0165, -0.0251, 0.0196, ..., 0.0213, 0.0040, -0.0228],\n", + " [-0.0145, 0.0218, -0.0032, ..., -0.0240, -0.0079, 0.0256],\n", + " [ 0.0539, -0.0027, -0.0227, ..., -0.0184, -0.0109, 0.0236]])),\n", + " ('model.layers.16.mlp.up_proj.weight',\n", + " tensor([[ 7.1125e-03, -3.2583e-04, -2.6297e-02, ..., -4.9575e-03,\n", + " -1.2243e-02, -1.3005e-02],\n", + " [ 2.5637e-02, -1.1874e-02, 1.1376e-02, ..., -1.4700e-02,\n", + " -1.5193e-02, 2.6111e-03],\n", + " [-4.8919e-02, -4.9716e-04, 5.8527e-03, ..., 8.6775e-05,\n", + " 1.0694e-02, 3.7682e-03],\n", + " ...,\n", + " [ 8.8393e-03, -4.3317e-02, 2.8372e-02, ..., 2.2709e-02,\n", + " -4.8128e-03, 1.6899e-02],\n", + " [ 1.3257e-02, 2.1000e-02, 1.5035e-03, ..., 1.5603e-02,\n", + " -5.5857e-03, 4.0449e-03],\n", + " [-2.6754e-02, -1.6263e-02, 1.9013e-02, ..., -9.0918e-03,\n", + " -8.0242e-03, -1.0925e-02]])),\n", + " ('model.layers.16.mlp.down_proj.weight',\n", + " tensor([[ 0.0207, -0.0038, -0.0234, ..., 0.0299, -0.0329, -0.0117],\n", + " [-0.0316, 0.0032, 0.0131, ..., 0.0020, -0.0320, 0.0381],\n", + " [-0.0192, -0.0031, -0.0030, ..., -0.0224, 0.0037, 0.0085],\n", + " ...,\n", + " [ 0.0044, 0.0281, -0.0208, ..., 0.0179, -0.0085, -0.0010],\n", + " [-0.0076, -0.0008, 0.0483, ..., 0.0082, -0.0177, -0.0039],\n", + " [ 0.0224, 0.0019, 0.0181, ..., 0.0143, -0.0252, 0.0022]])),\n", + " ('model.layers.16.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.16.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.17.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.17.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.17.mixer.in_proj.weight',\n", + " tensor([[-0.0115, 0.0061, -0.0062, ..., -0.0132, -0.0047, 0.0274],\n", + " [ 0.0076, 0.0278, -0.0147, ..., 0.0439, -0.0093, -0.0154],\n", + " [-0.0383, -0.0264, -0.0053, ..., -0.0206, 0.0275, 0.0188],\n", + " ...,\n", + " [ 0.0096, 0.0228, 0.0351, ..., 0.0227, 0.0138, -0.0164],\n", + " [ 0.0321, -0.0293, -0.0054, ..., 0.0109, -0.0113, -0.0130],\n", + " [-0.0120, -0.0132, 0.0092, ..., -0.0338, 0.0308, -0.0135]])),\n", + " ('model.layers.17.mixer.conv1d.weight',\n", + " tensor([[[-0.4933, 0.4156, 0.2523, -0.0026]],\n", + " \n", + " [[-0.2572, 0.4916, 0.3642, -0.2145]],\n", + " \n", + " [[ 0.0261, 0.4852, -0.1448, 0.2288]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3698, -0.4122, -0.2264, -0.1378]],\n", + " \n", + " [[ 0.1447, 0.4556, -0.0466, 0.0389]],\n", + " \n", + " [[-0.3891, 0.4149, 0.1454, -0.4282]]])),\n", + " ('model.layers.17.mixer.conv1d.bias',\n", + " tensor([-0.3919, -0.4015, 0.2591, ..., -0.3368, 0.2285, 0.1701])),\n", + " ('model.layers.17.mixer.out_proj.weight',\n", + " tensor([[-0.0127, -0.0155, 0.0193, ..., 0.0204, 0.0025, 0.0159],\n", + " [ 0.0192, 0.0194, -0.0169, ..., -0.0062, 0.0262, 0.0070],\n", + " [ 0.0397, 0.0009, 0.0189, ..., -0.0082, 0.0352, -0.0150],\n", + " ...,\n", + " [-0.0339, -0.0142, -0.0151, ..., 0.0229, 0.0032, 0.0038],\n", + " [ 0.0235, 0.0319, -0.0137, ..., -0.0121, 0.0112, 0.0162],\n", + " [ 0.0060, 0.0102, -0.0016, ..., 0.0118, 0.0158, -0.0140]])),\n", + " ('model.layers.17.mlp.gate_proj.weight',\n", + " tensor([[ 0.0285, -0.0090, -0.0095, ..., 0.0315, -0.0065, 0.0189],\n", + " [ 0.0040, -0.0358, -0.0039, ..., -0.0074, -0.0285, -0.0223],\n", + " [ 0.0202, 0.0021, -0.0104, ..., -0.0083, 0.0300, -0.0267],\n", + " ...,\n", + " [ 0.0093, -0.0008, -0.0372, ..., 0.0422, 0.0309, 0.0095],\n", + " [ 0.0027, 0.0252, 0.0378, ..., -0.0238, 0.0234, -0.0062],\n", + " [-0.0061, -0.0022, -0.0033, ..., 0.0157, -0.0296, 0.0034]])),\n", + " ('model.layers.17.mlp.up_proj.weight',\n", + " tensor([[ 0.0061, -0.0135, 0.0029, ..., 0.0328, 0.0008, -0.0072],\n", + " [ 0.0145, -0.0226, -0.0095, ..., 0.0114, 0.0224, -0.0160],\n", + " [ 0.0097, -0.0024, -0.0179, ..., 0.0073, -0.0061, -0.0195],\n", + " ...,\n", + " [ 0.0308, -0.0014, 0.0104, ..., 0.0047, 0.0026, 0.0243],\n", + " [-0.0364, 0.0350, 0.0031, ..., -0.0072, 0.0267, 0.0017],\n", + " [ 0.0227, -0.0146, 0.0146, ..., -0.0434, -0.0159, 0.0230]])),\n", + " ('model.layers.17.mlp.down_proj.weight',\n", + " tensor([[-0.0216, 0.0211, 0.0136, ..., -0.0004, 0.0051, 0.0415],\n", + " [-0.0061, -0.0123, 0.0156, ..., -0.0005, -0.0183, -0.0137],\n", + " [-0.0146, -0.0274, -0.0439, ..., -0.0033, -0.0030, -0.0074],\n", + " ...,\n", + " [-0.0108, -0.0005, -0.0094, ..., -0.0243, 0.0065, -0.0005],\n", + " [-0.0126, 0.0124, -0.0006, ..., -0.0282, -0.0110, 0.0128],\n", + " [-0.0162, -0.0102, 0.0025, ..., -0.0084, 0.0066, -0.0074]])),\n", + " ('model.layers.17.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.17.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.18.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.18.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.18.mixer.in_proj.weight',\n", + " tensor([[-9.4961e-03, -1.2349e-04, -7.1455e-03, ..., 1.9508e-02,\n", + " -6.8715e-03, -1.3565e-02],\n", + " [-2.9701e-03, 3.1580e-03, 1.8849e-02, ..., 7.6566e-03,\n", + " -1.0968e-02, -8.0445e-03],\n", + " [-1.5402e-02, -6.7267e-03, 9.6119e-03, ..., 1.9799e-02,\n", + " 2.0198e-03, -1.7366e-03],\n", + " ...,\n", + " [ 8.2379e-03, 5.1668e-03, 3.8116e-02, ..., -3.8710e-03,\n", + " 1.4452e-02, -2.5152e-02],\n", + " [ 1.1949e-02, -1.2245e-03, 1.0568e-02, ..., -3.1690e-02,\n", + " 3.8135e-05, 1.7263e-02],\n", + " [ 1.6173e-04, 5.6721e-04, 2.1043e-02, ..., -3.6167e-02,\n", + " -1.1129e-02, -9.6768e-03]])),\n", + " ('model.layers.18.mixer.conv1d.weight',\n", + " tensor([[[ 0.2776, 0.2169, -0.2840, 0.1736]],\n", + " \n", + " [[-0.0598, -0.2654, 0.2423, -0.0874]],\n", + " \n", + " [[-0.3612, -0.3049, -0.3197, -0.2763]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.1389, 0.2034, -0.1739, 0.1634]],\n", + " \n", + " [[-0.2836, -0.0471, 0.1284, -0.0099]],\n", + " \n", + " [[ 0.2952, -0.2676, -0.3961, 0.2656]]])),\n", + " ('model.layers.18.mixer.conv1d.bias',\n", + " tensor([ 0.1804, 0.0336, 0.4006, ..., 0.2943, -0.1079, 0.0963])),\n", + " ('model.layers.18.mixer.out_proj.weight',\n", + " tensor([[ 0.0109, -0.0181, 0.0148, ..., -0.0105, -0.0011, -0.0052],\n", + " [ 0.0507, 0.0100, -0.0273, ..., -0.0069, 0.0054, 0.0129],\n", + " [ 0.0014, 0.0423, -0.0193, ..., -0.0023, -0.0293, 0.0004],\n", + " ...,\n", + " [ 0.0420, -0.0401, 0.0205, ..., 0.0135, -0.0089, -0.0023],\n", + " [ 0.0242, 0.0273, 0.0139, ..., -0.0402, 0.0061, 0.0119],\n", + " [-0.0145, 0.0102, 0.0245, ..., 0.0205, -0.0251, 0.0006]])),\n", + " ('model.layers.18.mlp.gate_proj.weight',\n", + " tensor([[ 0.0241, -0.0086, 0.0136, ..., -0.0219, -0.0064, -0.0142],\n", + " [-0.0067, 0.0252, 0.0246, ..., -0.0205, -0.0273, 0.0137],\n", + " [-0.0030, 0.0055, -0.0063, ..., 0.0107, 0.0083, -0.0037],\n", + " ...,\n", + " [-0.0154, 0.0101, 0.0221, ..., 0.0025, -0.0109, 0.0133],\n", + " [-0.0175, 0.0105, -0.0246, ..., 0.0244, 0.0023, 0.0080],\n", + " [-0.0060, 0.0183, 0.0297, ..., 0.0420, -0.0006, -0.0119]])),\n", + " ('model.layers.18.mlp.up_proj.weight',\n", + " tensor([[ 0.0066, -0.0009, -0.0070, ..., -0.0064, 0.0002, 0.0196],\n", + " [-0.0173, -0.0362, -0.0011, ..., 0.0158, -0.0198, -0.0046],\n", + " [ 0.0133, -0.0090, -0.0092, ..., 0.0039, -0.0052, -0.0101],\n", + " ...,\n", + " [ 0.0077, -0.0063, 0.0010, ..., 0.0091, 0.0218, 0.0132],\n", + " [ 0.0005, -0.0046, 0.0207, ..., 0.0112, 0.0183, -0.0020],\n", + " [ 0.0238, -0.0022, 0.0364, ..., -0.0042, 0.0237, 0.0183]])),\n", + " ('model.layers.18.mlp.down_proj.weight',\n", + " tensor([[ 0.0305, 0.0178, -0.0264, ..., -0.0158, 0.0135, 0.0132],\n", + " [ 0.0248, -0.0061, 0.0144, ..., -0.0165, 0.0098, 0.0410],\n", + " [-0.0156, -0.0039, 0.0112, ..., -0.0431, -0.0084, -0.0197],\n", + " ...,\n", + " [ 0.0071, 0.0236, -0.0038, ..., 0.0035, -0.0236, 0.0106],\n", + " [-0.0369, -0.0029, -0.0182, ..., -0.0008, -0.0417, 0.0064],\n", + " [-0.0273, 0.0207, 0.0130, ..., 0.0372, 0.0163, 0.0273]])),\n", + " ('model.layers.18.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.18.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.19.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.19.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.19.mixer.in_proj.weight',\n", + " tensor([[-0.0079, 0.0147, -0.0337, ..., -0.0201, -0.0254, 0.0035],\n", + " [ 0.0139, 0.0054, -0.0093, ..., -0.0208, -0.0289, -0.0087],\n", + " [ 0.0004, -0.0034, 0.0090, ..., -0.0109, -0.0093, 0.0102],\n", + " ...,\n", + " [ 0.0128, 0.0015, -0.0101, ..., -0.0482, -0.0217, 0.0144],\n", + " [-0.0100, -0.0079, 0.0286, ..., -0.0025, -0.0210, 0.0164],\n", + " [-0.0264, 0.0015, 0.0031, ..., 0.0027, 0.0131, -0.0384]])),\n", + " ('model.layers.19.mixer.conv1d.weight',\n", + " tensor([[[ 0.4729, 0.3708, -0.4394, -0.3549]],\n", + " \n", + " [[ 0.2230, -0.3271, 0.3017, -0.2552]],\n", + " \n", + " [[-0.0417, 0.1893, 0.4552, -0.0644]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.2565, 0.0407, 0.3521, 0.4116]],\n", + " \n", + " [[ 0.0795, -0.0374, 0.1034, 0.4254]],\n", + " \n", + " [[ 0.3333, 0.2431, 0.3459, -0.2676]]])),\n", + " ('model.layers.19.mixer.conv1d.bias',\n", + " tensor([-0.2287, -0.4446, -0.2300, ..., -0.2317, -0.3395, 0.4310])),\n", + " ('model.layers.19.mixer.out_proj.weight',\n", + " tensor([[-0.0456, -0.0167, -0.0117, ..., -0.0068, -0.0150, 0.0125],\n", + " [ 0.0194, 0.0172, -0.0232, ..., -0.0202, -0.0066, 0.0083],\n", + " [ 0.0320, -0.0065, 0.0274, ..., 0.0200, 0.0090, 0.0105],\n", + " ...,\n", + " [ 0.0315, 0.0415, 0.0128, ..., -0.0143, -0.0338, -0.0231],\n", + " [ 0.0227, -0.0177, -0.0034, ..., 0.0174, 0.0006, 0.0212],\n", + " [ 0.0358, 0.0084, 0.0075, ..., 0.0091, 0.0062, 0.0114]])),\n", + " ('model.layers.19.mlp.gate_proj.weight',\n", + " tensor([[-0.0010, 0.0156, 0.0042, ..., -0.0181, 0.0113, 0.0089],\n", + " [-0.0182, 0.0068, -0.0043, ..., -0.0323, -0.0019, -0.0045],\n", + " [ 0.0168, -0.0093, -0.0162, ..., -0.0074, 0.0166, -0.0334],\n", + " ...,\n", + " [ 0.0038, -0.0211, -0.0054, ..., -0.0229, 0.0193, -0.0210],\n", + " [ 0.0153, -0.0372, 0.0119, ..., 0.0043, -0.0097, -0.0025],\n", + " [ 0.0037, 0.0208, -0.0135, ..., 0.0052, -0.0125, -0.0282]])),\n", + " ('model.layers.19.mlp.up_proj.weight',\n", + " tensor([[-0.0026, 0.0360, 0.0161, ..., 0.0199, -0.0283, -0.0026],\n", + " [ 0.0185, 0.0122, -0.0299, ..., 0.0125, 0.0063, 0.0387],\n", + " [-0.0085, -0.0010, -0.0054, ..., -0.0088, -0.0034, -0.0179],\n", + " ...,\n", + " [-0.0179, 0.0211, -0.0003, ..., -0.0071, -0.0145, 0.0235],\n", + " [-0.0002, 0.0060, -0.0172, ..., -0.0086, 0.0175, -0.0232],\n", + " [-0.0081, -0.0280, -0.0152, ..., -0.0221, 0.0047, -0.0077]])),\n", + " ('model.layers.19.mlp.down_proj.weight',\n", + " tensor([[ 0.0038, -0.0027, -0.0122, ..., 0.0090, 0.0044, 0.0128],\n", + " [ 0.0054, 0.0075, 0.0116, ..., 0.0232, 0.0130, 0.0298],\n", + " [-0.0498, -0.0208, -0.0127, ..., 0.0166, -0.0221, 0.0038],\n", + " ...,\n", + " [ 0.0101, 0.0051, 0.0209, ..., 0.0137, -0.0225, 0.0142],\n", + " [-0.0433, -0.0217, -0.0167, ..., -0.0179, -0.0191, -0.0021],\n", + " [-0.0020, 0.0084, -0.0114, ..., 0.0324, 0.0216, -0.0062]])),\n", + " ('model.layers.19.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.19.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.20.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.20.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.20.mixer.in_proj.weight',\n", + " tensor([[ 3.3776e-02, 3.6619e-02, 6.8532e-03, ..., 5.7664e-02,\n", + " -2.3083e-02, -6.2962e-02],\n", + " [-2.9787e-03, -2.5050e-03, -3.4841e-03, ..., 5.4946e-03,\n", + " 9.0683e-03, 2.1583e-04],\n", + " [ 7.4430e-03, -1.0495e-02, 3.5169e-02, ..., -5.1808e-02,\n", + " 3.2650e-03, -3.1967e-02],\n", + " ...,\n", + " [-5.8685e-02, 4.8452e-02, -1.2612e-02, ..., 1.2174e-02,\n", + " 1.0566e-02, -4.9561e-03],\n", + " [ 3.1722e-03, -2.9390e-03, 1.4502e-05, ..., -2.3297e-02,\n", + " -7.5403e-03, -1.3599e-02],\n", + " [ 1.4845e-02, -4.3150e-02, -1.0338e-02, ..., -1.1149e-02,\n", + " -3.3432e-02, 3.8337e-03]])),\n", + " ('model.layers.20.mixer.conv1d.weight',\n", + " tensor([[[-0.3842, 0.2397, 0.4873, -0.3091]],\n", + " \n", + " [[-0.1886, 0.0751, 0.2026, -0.2674]],\n", + " \n", + " [[-0.0594, 0.3119, -0.2404, 0.1652]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0028, 0.1315, 0.0515, 0.3189]],\n", + " \n", + " [[-0.1461, -0.0457, -0.0536, -0.2306]],\n", + " \n", + " [[-0.3025, -0.3339, 0.3007, -0.3007]]])),\n", + " ('model.layers.20.mixer.conv1d.bias',\n", + " tensor([-0.4901, -0.3784, -0.0173, ..., -0.3946, -0.0728, 0.2187])),\n", + " ('model.layers.20.mixer.out_proj.weight',\n", + " tensor([[ 0.0095, -0.0037, -0.0218, ..., 0.0080, 0.0062, 0.0246],\n", + " [-0.0197, 0.0037, 0.0076, ..., 0.0171, 0.0238, -0.0195],\n", + " [ 0.0364, -0.0165, 0.0224, ..., -0.0099, 0.0007, 0.0340],\n", + " ...,\n", + " [ 0.0235, -0.0072, -0.0319, ..., 0.0045, -0.0196, 0.0011],\n", + " [-0.0369, 0.0083, 0.0021, ..., -0.0357, -0.0039, -0.0150],\n", + " [-0.0174, -0.0211, 0.0111, ..., 0.0251, 0.0040, -0.0308]])),\n", + " ('model.layers.20.mlp.gate_proj.weight',\n", + " tensor([[ 0.0161, -0.0019, -0.0473, ..., 0.0019, 0.0075, -0.0038],\n", + " [-0.0321, -0.0020, -0.0100, ..., 0.0035, 0.0291, -0.0058],\n", + " [-0.0158, 0.0020, 0.0353, ..., 0.0125, 0.0228, -0.0392],\n", + " ...,\n", + " [ 0.0113, 0.0171, 0.0235, ..., 0.0043, 0.0378, 0.0391],\n", + " [ 0.0090, 0.0067, 0.0031, ..., 0.0291, -0.0052, -0.0216],\n", + " [ 0.0042, -0.0112, -0.0161, ..., -0.0063, -0.0156, 0.0211]])),\n", + " ('model.layers.20.mlp.up_proj.weight',\n", + " tensor([[ 0.0104, -0.0302, -0.0220, ..., -0.0072, -0.0083, -0.0066],\n", + " [ 0.0409, -0.0116, -0.0125, ..., 0.0182, 0.0267, 0.0099],\n", + " [-0.0055, 0.0104, 0.0027, ..., -0.0075, -0.0368, -0.0092],\n", + " ...,\n", + " [-0.0089, 0.0243, -0.0028, ..., -0.0136, -0.0176, -0.0054],\n", + " [ 0.0088, 0.0365, -0.0354, ..., 0.0035, 0.0280, 0.0155],\n", + " [-0.0472, 0.0088, 0.0102, ..., -0.0120, 0.0004, -0.0011]])),\n", + " ('model.layers.20.mlp.down_proj.weight',\n", + " tensor([[-0.0089, -0.0112, -0.0007, ..., 0.0360, -0.0077, 0.0261],\n", + " [ 0.0080, -0.0128, -0.0445, ..., 0.0095, -0.0298, 0.0176],\n", + " [ 0.0357, -0.0262, 0.0028, ..., 0.0162, 0.0089, 0.0050],\n", + " ...,\n", + " [-0.0129, 0.0216, 0.0125, ..., -0.0062, -0.0344, -0.0218],\n", + " [ 0.0006, -0.0143, -0.0099, ..., -0.0359, 0.0268, 0.0259],\n", + " [ 0.0222, -0.0154, 0.0013, ..., 0.0108, -0.0077, 0.0186]])),\n", + " ('model.layers.20.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.20.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.21.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.21.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.21.mixer.in_proj.weight',\n", + " tensor([[-0.0300, 0.0058, -0.0107, ..., -0.0318, 0.0350, 0.0350],\n", + " [ 0.0186, 0.0238, -0.0268, ..., 0.0142, -0.0277, -0.0095],\n", + " [-0.0061, 0.0083, 0.0072, ..., 0.0161, 0.0027, -0.0051],\n", + " ...,\n", + " [-0.0358, 0.0330, 0.0151, ..., -0.0376, 0.0057, 0.0174],\n", + " [-0.0021, 0.0068, 0.0151, ..., 0.0077, -0.0353, 0.0095],\n", + " [-0.0113, -0.0043, 0.0064, ..., -0.0063, -0.0232, -0.0058]])),\n", + " ('model.layers.21.mixer.conv1d.weight',\n", + " tensor([[[ 0.0354, 0.0496, -0.0106, 0.0084]],\n", + " \n", + " [[ 0.2553, 0.3217, -0.0078, -0.2333]],\n", + " \n", + " [[-0.1390, 0.0323, 0.4914, -0.2047]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.2243, 0.2984, 0.0188, 0.1830]],\n", + " \n", + " [[ 0.0756, 0.1443, -0.4898, -0.2082]],\n", + " \n", + " [[-0.3685, -0.1311, -0.4037, -0.3276]]])),\n", + " ('model.layers.21.mixer.conv1d.bias',\n", + " tensor([-0.2444, -0.1852, 0.2215, ..., 0.4515, 0.2532, -0.2388])),\n", + " ('model.layers.21.mixer.out_proj.weight',\n", + " tensor([[ 0.0232, 0.0328, 0.0026, ..., -0.0575, 0.0157, -0.0072],\n", + " [-0.0226, 0.0058, -0.0346, ..., 0.0092, 0.0078, 0.0108],\n", + " [ 0.0045, 0.0247, 0.0150, ..., -0.0085, 0.0268, 0.0253],\n", + " ...,\n", + " [ 0.0268, 0.0092, 0.0141, ..., 0.0062, 0.0177, -0.0405],\n", + " [ 0.0163, -0.0269, -0.0177, ..., 0.0029, -0.0080, -0.0036],\n", + " [ 0.0064, 0.0126, 0.0126, ..., -0.0400, -0.0015, -0.0088]])),\n", + " ('model.layers.21.mlp.gate_proj.weight',\n", + " tensor([[-3.7050e-02, 4.5834e-02, 1.9280e-02, ..., 1.6761e-02,\n", + " -5.8295e-03, -1.4284e-02],\n", + " [ 3.0156e-02, 3.2832e-02, 1.1083e-02, ..., -5.8261e-03,\n", + " -3.9076e-02, 5.3379e-03],\n", + " [ 1.3118e-03, 3.1510e-02, 1.5472e-02, ..., 1.8213e-02,\n", + " -2.5180e-02, 6.1512e-04],\n", + " ...,\n", + " [ 4.2010e-02, 1.0362e-02, 7.1759e-03, ..., 1.8667e-03,\n", + " -7.2165e-03, 1.6297e-02],\n", + " [ 1.8175e-02, 1.2840e-02, 3.2857e-03, ..., 1.8495e-02,\n", + " -7.7709e-03, 4.3964e-04],\n", + " [-9.2628e-05, 2.1701e-02, 2.1256e-02, ..., 2.5241e-02,\n", + " 5.0683e-02, -2.5481e-02]])),\n", + " ('model.layers.21.mlp.up_proj.weight',\n", + " tensor([[ 0.0228, 0.0082, -0.0083, ..., 0.0288, 0.0211, 0.0085],\n", + " [-0.0155, 0.0179, 0.0111, ..., -0.0218, -0.0162, -0.0052],\n", + " [ 0.0016, 0.0009, 0.0230, ..., -0.0017, 0.0131, 0.0255],\n", + " ...,\n", + " [-0.0098, -0.0098, -0.0188, ..., 0.0063, 0.0082, 0.0052],\n", + " [-0.0028, 0.0249, -0.0153, ..., -0.0208, 0.0130, -0.0093],\n", + " [ 0.0105, -0.0072, -0.0379, ..., 0.0035, 0.0182, 0.0307]])),\n", + " ('model.layers.21.mlp.down_proj.weight',\n", + " tensor([[-0.0445, -0.0116, 0.0058, ..., 0.0081, -0.0099, 0.0094],\n", + " [ 0.0106, -0.0387, 0.0051, ..., 0.0017, 0.0075, 0.0136],\n", + " [ 0.0022, 0.0058, -0.0268, ..., -0.0088, -0.0149, 0.0125],\n", + " ...,\n", + " [-0.0015, -0.0156, -0.0225, ..., 0.0100, -0.0118, -0.0019],\n", + " [-0.0161, -0.0225, -0.0060, ..., 0.0073, -0.0072, 0.0205],\n", + " [-0.0112, 0.0046, -0.0089, ..., -0.0014, -0.0221, 0.0124]])),\n", + " ('model.layers.21.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.21.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.22.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.22.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.22.mixer.in_proj.weight',\n", + " tensor([[-1.1591e-02, -6.0118e-03, -2.2227e-03, ..., -7.1433e-03,\n", + " -1.5757e-02, -1.5315e-03],\n", + " [-7.6057e-03, -4.2199e-02, 1.4478e-02, ..., 5.6496e-02,\n", + " 8.9105e-05, -3.8658e-03],\n", + " [-1.0330e-03, 2.3586e-02, 2.1835e-02, ..., -1.4911e-03,\n", + " -1.6604e-02, -4.5245e-03],\n", + " ...,\n", + " [-6.7261e-03, -6.9826e-03, -9.3003e-03, ..., -4.3939e-02,\n", + " 2.3792e-02, -5.5165e-03],\n", + " [-1.1798e-02, -3.4709e-02, -4.1277e-03, ..., -5.1867e-03,\n", + " 5.2496e-03, -6.0055e-03],\n", + " [ 7.3402e-04, -1.9525e-02, -5.8966e-03, ..., -1.5972e-02,\n", + " -1.5446e-02, -2.7164e-02]])),\n", + " ('model.layers.22.mixer.conv1d.weight',\n", + " tensor([[[-0.3791, 0.0616, 0.0369, 0.1365]],\n", + " \n", + " [[-0.4674, -0.4557, 0.3894, -0.4765]],\n", + " \n", + " [[ 0.3333, 0.2265, 0.1385, -0.1352]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4363, -0.3526, -0.3982, -0.1049]],\n", + " \n", + " [[ 0.4798, -0.3912, 0.4059, -0.1379]],\n", + " \n", + " [[-0.4427, 0.4661, -0.1990, 0.1668]]])),\n", + " ('model.layers.22.mixer.conv1d.bias',\n", + " tensor([-0.1823, -0.4117, 0.4443, ..., -0.0024, 0.2144, -0.4922])),\n", + " ('model.layers.22.mixer.out_proj.weight',\n", + " tensor([[ 0.0138, -0.0169, -0.0349, ..., -0.0045, 0.0023, -0.0389],\n", + " [ 0.0250, 0.0040, -0.0259, ..., 0.0458, 0.0311, -0.0054],\n", + " [-0.0056, 0.0012, -0.0027, ..., 0.0095, -0.0089, -0.0106],\n", + " ...,\n", + " [ 0.0228, -0.0258, 0.0040, ..., 0.0276, -0.0121, -0.0239],\n", + " [ 0.0082, 0.0041, 0.0145, ..., 0.0079, -0.0076, 0.0177],\n", + " [ 0.0310, -0.0092, -0.0174, ..., 0.0179, 0.0231, -0.0035]])),\n", + " ('model.layers.22.mlp.gate_proj.weight',\n", + " tensor([[ 0.0090, -0.0178, -0.0120, ..., -0.0073, -0.0149, 0.0187],\n", + " [ 0.0263, -0.0093, -0.0074, ..., -0.0472, 0.0049, 0.0288],\n", + " [ 0.0159, -0.0083, 0.0291, ..., 0.0089, -0.0076, -0.0167],\n", + " ...,\n", + " [-0.0008, 0.0206, 0.0199, ..., -0.0134, -0.0366, -0.0202],\n", + " [-0.0069, -0.0275, 0.0054, ..., 0.0093, 0.0108, 0.0094],\n", + " [ 0.0198, 0.0033, -0.0118, ..., -0.0262, 0.0241, 0.0084]])),\n", + " ('model.layers.22.mlp.up_proj.weight',\n", + " tensor([[-0.0277, 0.0038, 0.0006, ..., -0.0222, -0.0313, -0.0133],\n", + " [ 0.0132, -0.0373, 0.0109, ..., 0.0359, -0.0116, 0.0099],\n", + " [ 0.0139, -0.0185, 0.0247, ..., 0.0178, 0.0192, 0.0049],\n", + " ...,\n", + " [ 0.0362, 0.0072, -0.0236, ..., -0.0238, 0.0319, -0.0210],\n", + " [ 0.0013, -0.0047, -0.0060, ..., 0.0106, -0.0074, -0.0185],\n", + " [-0.0228, 0.0176, -0.0047, ..., -0.0034, -0.0174, -0.0264]])),\n", + " ('model.layers.22.mlp.down_proj.weight',\n", + " tensor([[ 0.0149, 0.0122, -0.0037, ..., 0.0044, 0.0171, -0.0186],\n", + " [-0.0037, -0.0002, 0.0066, ..., 0.0263, -0.0025, -0.0012],\n", + " [-0.0075, 0.0209, 0.0045, ..., 0.0082, -0.0160, 0.0079],\n", + " ...,\n", + " [ 0.0001, 0.0507, -0.0078, ..., 0.0001, -0.0119, 0.0286],\n", + " [-0.0198, -0.0122, 0.0047, ..., -0.0052, 0.0130, -0.0007],\n", + " [ 0.0241, -0.0002, -0.0147, ..., 0.0219, -0.0020, -0.0071]])),\n", + " ('model.layers.22.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.22.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.23.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.23.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.23.mixer.in_proj.weight',\n", + " tensor([[-0.0017, 0.0027, -0.0150, ..., 0.0392, -0.0079, -0.0367],\n", + " [ 0.0183, 0.0261, -0.0262, ..., -0.0157, 0.0197, 0.0135],\n", + " [-0.0030, 0.0170, 0.0032, ..., 0.0059, 0.0299, 0.0158],\n", + " ...,\n", + " [-0.0149, 0.0218, 0.0072, ..., -0.0302, 0.0035, 0.0153],\n", + " [-0.0135, 0.0425, 0.0331, ..., -0.0119, -0.0364, 0.0365],\n", + " [-0.0215, -0.0242, 0.0271, ..., 0.0500, 0.0293, 0.0100]])),\n", + " ('model.layers.23.mixer.conv1d.weight',\n", + " tensor([[[ 0.2464, 0.3726, 0.2719, 0.3580]],\n", + " \n", + " [[-0.0520, 0.0010, 0.1396, -0.4634]],\n", + " \n", + " [[ 0.1383, 0.4039, -0.3622, 0.1499]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.4094, 0.0541, 0.2240, -0.1545]],\n", + " \n", + " [[-0.4393, 0.1323, 0.1705, -0.1722]],\n", + " \n", + " [[ 0.2166, -0.4335, -0.4088, -0.1159]]])),\n", + " ('model.layers.23.mixer.conv1d.bias',\n", + " tensor([ 0.3175, -0.0325, -0.4654, ..., 0.3869, -0.2534, 0.1588])),\n", + " ('model.layers.23.mixer.out_proj.weight',\n", + " tensor([[-0.0354, -0.0041, 0.0196, ..., -0.0218, -0.0222, 0.0126],\n", + " [-0.0155, -0.0067, -0.0007, ..., 0.0112, -0.0036, -0.0054],\n", + " [ 0.0141, 0.0040, -0.0218, ..., -0.0178, -0.0031, 0.0162],\n", + " ...,\n", + " [ 0.0264, 0.0063, 0.0088, ..., -0.0310, -0.0116, 0.0239],\n", + " [-0.0031, 0.0056, -0.0243, ..., -0.0350, 0.0004, 0.0004],\n", + " [ 0.0229, -0.0201, 0.0124, ..., 0.0313, -0.0412, -0.0033]])),\n", + " ('model.layers.23.mlp.gate_proj.weight',\n", + " tensor([[ 0.0026, -0.0155, 0.0595, ..., 0.0204, 0.0172, 0.0378],\n", + " [-0.0011, -0.0253, 0.0039, ..., 0.0330, -0.0487, -0.0195],\n", + " [ 0.0174, 0.0039, -0.0029, ..., -0.0026, 0.0104, 0.0108],\n", + " ...,\n", + " [-0.0159, 0.0008, 0.0173, ..., -0.0020, 0.0085, -0.0043],\n", + " [ 0.0101, 0.0221, -0.0034, ..., -0.0268, 0.0056, 0.0137],\n", + " [-0.0031, -0.0151, 0.0073, ..., -0.0083, -0.0064, 0.0109]])),\n", + " ('model.layers.23.mlp.up_proj.weight',\n", + " tensor([[ 0.0173, -0.0132, -0.0027, ..., 0.0391, 0.0268, -0.0185],\n", + " [ 0.0221, -0.0110, -0.0108, ..., -0.0302, 0.0170, 0.0139],\n", + " [-0.0047, -0.0373, 0.0056, ..., -0.0389, -0.0175, -0.0410],\n", + " ...,\n", + " [ 0.0003, 0.0153, 0.0160, ..., 0.0002, -0.0136, 0.0417],\n", + " [-0.0059, -0.0150, -0.0111, ..., 0.0163, 0.0171, 0.0267],\n", + " [-0.0123, -0.0032, 0.0193, ..., -0.0051, -0.0051, -0.0089]])),\n", + " ('model.layers.23.mlp.down_proj.weight',\n", + " tensor([[-0.0092, -0.0148, -0.0345, ..., -0.0240, 0.0425, -0.0099],\n", + " [ 0.0458, 0.0156, -0.0067, ..., -0.0283, 0.0401, 0.0074],\n", + " [ 0.0180, -0.0008, 0.0049, ..., -0.0085, -0.0157, 0.0044],\n", + " ...,\n", + " [-0.0207, 0.0074, -0.0176, ..., 0.0038, -0.0238, -0.0026],\n", + " [-0.0201, 0.0078, 0.0243, ..., -0.0031, 0.0080, -0.0176],\n", + " [-0.0034, 0.0191, 0.0391, ..., -0.0114, 0.0133, -0.0261]])),\n", + " ('model.layers.23.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.23.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.24.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.24.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.24.mixer.in_proj.weight',\n", + " tensor([[-0.0184, -0.0299, 0.0165, ..., 0.0035, 0.0417, -0.0170],\n", + " [-0.0346, -0.0226, 0.0064, ..., 0.0072, 0.0457, -0.0148],\n", + " [ 0.0032, -0.0245, -0.0474, ..., -0.0054, -0.0044, 0.0278],\n", + " ...,\n", + " [ 0.0139, 0.0133, -0.0185, ..., 0.0188, 0.0119, -0.0205],\n", + " [ 0.0235, 0.0161, -0.0095, ..., 0.0013, -0.0382, 0.0213],\n", + " [ 0.0031, -0.0394, 0.0275, ..., -0.0068, 0.0024, 0.0179]])),\n", + " ('model.layers.24.mixer.conv1d.weight',\n", + " tensor([[[-0.1857, -0.4692, 0.4791, 0.3706]],\n", + " \n", + " [[ 0.1749, 0.4182, -0.2338, 0.0838]],\n", + " \n", + " [[-0.1204, -0.2985, -0.0470, 0.4674]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.1485, 0.3118, -0.4916, -0.1610]],\n", + " \n", + " [[ 0.0684, -0.2980, 0.4517, -0.3662]],\n", + " \n", + " [[ 0.2353, -0.2156, -0.3332, -0.0665]]])),\n", + " ('model.layers.24.mixer.conv1d.bias',\n", + " tensor([-0.4464, -0.3485, -0.3916, ..., 0.2513, -0.0601, 0.1546])),\n", + " ('model.layers.24.mixer.out_proj.weight',\n", + " tensor([[-0.0023, 0.0087, -0.0280, ..., 0.0338, -0.0095, -0.0237],\n", + " [-0.0086, -0.0084, 0.0180, ..., 0.0350, 0.0463, -0.0270],\n", + " [-0.0093, -0.0009, 0.0236, ..., 0.0158, 0.0246, 0.0068],\n", + " ...,\n", + " [ 0.0526, 0.0009, 0.0039, ..., -0.0206, -0.0538, 0.0287],\n", + " [ 0.0054, -0.0053, -0.0108, ..., 0.0167, -0.0997, 0.0036],\n", + " [ 0.0009, -0.0297, -0.0424, ..., -0.0096, -0.0235, 0.0117]])),\n", + " ('model.layers.24.mlp.gate_proj.weight',\n", + " tensor([[-0.0265, 0.0259, 0.0224, ..., -0.0080, -0.0394, 0.0290],\n", + " [-0.0101, -0.0256, 0.0079, ..., -0.0017, -0.0287, -0.0163],\n", + " [ 0.0079, -0.0021, -0.0299, ..., 0.0076, 0.0063, 0.0082],\n", + " ...,\n", + " [ 0.0061, 0.0121, 0.0275, ..., -0.0162, 0.0025, -0.0075],\n", + " [-0.0039, -0.0217, -0.0428, ..., -0.0253, 0.0231, 0.0095],\n", + " [-0.0187, 0.0077, -0.0442, ..., 0.0358, -0.0084, -0.0132]])),\n", + " ('model.layers.24.mlp.up_proj.weight',\n", + " tensor([[-0.0201, -0.0119, 0.0505, ..., -0.0025, -0.0187, 0.0011],\n", + " [-0.0105, 0.0154, -0.0163, ..., 0.0248, 0.0028, 0.0178],\n", + " [-0.0163, -0.0271, -0.0100, ..., 0.0129, -0.0220, 0.0269],\n", + " ...,\n", + " [ 0.0138, 0.0329, -0.0091, ..., 0.0038, -0.0194, -0.0223],\n", + " [ 0.0469, 0.0291, -0.0027, ..., 0.0231, 0.0261, 0.0151],\n", + " [-0.0093, -0.0098, 0.0013, ..., 0.0078, -0.0145, 0.0268]])),\n", + " ('model.layers.24.mlp.down_proj.weight',\n", + " tensor([[-0.0195, -0.0003, -0.0046, ..., -0.0132, -0.0118, 0.0242],\n", + " [-0.0267, 0.0199, 0.0243, ..., -0.0063, 0.0134, -0.0163],\n", + " [-0.0044, -0.0303, -0.0215, ..., -0.0148, -0.0216, 0.0079],\n", + " ...,\n", + " [ 0.0159, 0.0180, 0.0098, ..., -0.0126, 0.0176, 0.0087],\n", + " [-0.0203, 0.0041, -0.0256, ..., -0.0047, -0.0236, -0.0256],\n", + " [-0.0017, 0.0133, 0.0490, ..., -0.0344, -0.0118, 0.0020]])),\n", + " ('model.layers.24.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.24.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.25.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.25.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.25.mixer.in_proj.weight',\n", + " tensor([[ 0.0064, 0.0039, 0.0014, ..., 0.0130, -0.0169, 0.0010],\n", + " [ 0.0371, 0.0241, 0.0203, ..., 0.0078, 0.0463, 0.0034],\n", + " [ 0.0184, -0.0431, -0.0026, ..., -0.0164, 0.0279, -0.0138],\n", + " ...,\n", + " [ 0.0146, -0.0138, -0.0418, ..., 0.0234, 0.0145, -0.0213],\n", + " [ 0.0124, -0.0298, -0.0164, ..., -0.0169, 0.0026, -0.0180],\n", + " [-0.0250, -0.0008, -0.0133, ..., -0.0131, -0.0064, 0.0071]])),\n", + " ('model.layers.25.mixer.conv1d.weight',\n", + " tensor([[[ 0.0171, -0.3423, -0.1701, 0.4869]],\n", + " \n", + " [[-0.4648, 0.4797, 0.3531, -0.3819]],\n", + " \n", + " [[-0.1660, -0.3489, -0.2488, 0.4428]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3545, -0.1567, -0.2646, 0.3590]],\n", + " \n", + " [[-0.2175, 0.4394, 0.3840, 0.2620]],\n", + " \n", + " [[ 0.1335, -0.3655, 0.3256, -0.1752]]])),\n", + " ('model.layers.25.mixer.conv1d.bias',\n", + " tensor([-0.0935, 0.0170, 0.0779, ..., -0.2362, 0.2879, 0.2390])),\n", + " ('model.layers.25.mixer.out_proj.weight',\n", + " tensor([[ 2.0220e-02, 5.0645e-05, -1.7425e-02, ..., 8.6082e-03,\n", + " -1.8566e-02, 1.3872e-02],\n", + " [ 2.9139e-02, 1.1096e-02, 4.4168e-02, ..., 3.5600e-02,\n", + " 7.3446e-03, -1.6368e-02],\n", + " [-3.2418e-02, 6.9682e-03, 3.1648e-02, ..., 1.4050e-02,\n", + " -1.6554e-02, 7.2751e-03],\n", + " ...,\n", + " [-3.3057e-02, -7.0545e-04, 3.9661e-02, ..., 2.0690e-02,\n", + " -1.0262e-02, -4.9292e-03],\n", + " [ 1.9849e-02, 1.9666e-02, -1.9398e-02, ..., 1.9285e-02,\n", + " 2.2522e-02, -6.0243e-03],\n", + " [ 1.7683e-02, 2.4301e-02, 7.2223e-03, ..., 3.1373e-02,\n", + " -5.7889e-03, 1.1855e-02]])),\n", + " ('model.layers.25.mlp.gate_proj.weight',\n", + " tensor([[-1.6223e-02, 4.5519e-03, -1.9218e-02, ..., 6.3580e-03,\n", + " -1.2723e-02, -9.7756e-03],\n", + " [-7.4200e-03, 1.8729e-02, 2.6924e-03, ..., 8.2305e-03,\n", + " -1.5727e-02, -9.8748e-03],\n", + " [ 3.2143e-02, -6.1559e-02, 1.6362e-02, ..., -3.6189e-04,\n", + " 1.2017e-04, -1.5734e-02],\n", + " ...,\n", + " [-1.4649e-02, -4.7663e-03, -1.9292e-02, ..., -1.9359e-02,\n", + " 1.8795e-02, 1.0221e-02],\n", + " [-2.4459e-02, 1.1684e-02, -2.8023e-02, ..., 8.0104e-03,\n", + " 8.5950e-05, 1.0542e-02],\n", + " [-4.5679e-03, -1.1421e-02, -2.1099e-02, ..., 4.5089e-03,\n", + " -3.0686e-02, -9.6116e-03]])),\n", + " ('model.layers.25.mlp.up_proj.weight',\n", + " tensor([[-0.0204, -0.0013, -0.0264, ..., -0.0081, -0.0027, 0.0215],\n", + " [-0.0161, 0.0051, -0.0111, ..., -0.0244, 0.0043, -0.0043],\n", + " [-0.0511, 0.0006, -0.0249, ..., 0.0069, 0.0615, 0.0123],\n", + " ...,\n", + " [-0.0086, -0.0016, 0.0064, ..., -0.0347, 0.0097, -0.0134],\n", + " [-0.0003, 0.0015, -0.0053, ..., 0.0210, 0.0135, 0.0337],\n", + " [-0.0205, 0.0028, -0.0272, ..., -0.0168, -0.0072, 0.0019]])),\n", + " ('model.layers.25.mlp.down_proj.weight',\n", + " tensor([[ 0.0166, 0.0044, 0.0180, ..., -0.0127, 0.0070, -0.0066],\n", + " [-0.0056, 0.0140, 0.0151, ..., -0.0239, -0.0140, 0.0470],\n", + " [-0.0030, -0.0093, -0.0188, ..., -0.0090, -0.0092, -0.0088],\n", + " ...,\n", + " [ 0.0465, 0.0277, -0.0349, ..., 0.0424, 0.0015, 0.0206],\n", + " [-0.0096, 0.0174, 0.0250, ..., -0.0142, -0.0022, -0.0141],\n", + " [-0.0195, -0.0174, 0.0033, ..., 0.0027, -0.0061, -0.0108]])),\n", + " ('model.layers.25.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.25.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.26.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.26.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.26.mixer.in_proj.weight',\n", + " tensor([[ 0.0112, 0.0060, -0.0038, ..., -0.0164, 0.0111, 0.0105],\n", + " [ 0.0227, -0.0248, 0.0240, ..., 0.0103, -0.0373, -0.0051],\n", + " [-0.0073, 0.0227, -0.0190, ..., 0.0048, -0.0101, -0.0137],\n", + " ...,\n", + " [ 0.0086, -0.0084, 0.0177, ..., -0.0245, 0.0119, 0.0022],\n", + " [-0.0080, -0.0284, 0.0440, ..., 0.0340, -0.0093, 0.0130],\n", + " [-0.0107, 0.0234, -0.0279, ..., 0.0106, -0.0169, -0.0001]])),\n", + " ('model.layers.26.mixer.conv1d.weight',\n", + " tensor([[[ 0.0550, -0.3464, -0.2378, -0.1244]],\n", + " \n", + " [[-0.0925, -0.2497, 0.2629, -0.1821]],\n", + " \n", + " [[-0.4524, 0.3462, -0.4604, -0.2758]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4555, -0.0839, 0.3936, -0.3707]],\n", + " \n", + " [[ 0.3409, -0.4109, 0.0890, -0.3629]],\n", + " \n", + " [[-0.2769, 0.4033, -0.1090, 0.3055]]])),\n", + " ('model.layers.26.mixer.conv1d.bias',\n", + " tensor([-0.2286, -0.2395, -0.2517, ..., 0.0537, 0.0906, 0.4936])),\n", + " ('model.layers.26.mixer.out_proj.weight',\n", + " tensor([[-0.0316, -0.0423, -0.0053, ..., 0.0024, 0.0084, -0.0270],\n", + " [ 0.0458, -0.0243, 0.0060, ..., -0.0007, -0.0161, -0.0232],\n", + " [ 0.0388, -0.0126, 0.0184, ..., -0.0059, 0.0061, 0.0090],\n", + " ...,\n", + " [ 0.0487, 0.0305, -0.0175, ..., -0.0250, -0.0158, -0.0035],\n", + " [-0.0148, -0.0224, 0.0095, ..., -0.0102, -0.0226, 0.0272],\n", + " [-0.0061, 0.0067, 0.0069, ..., 0.0038, -0.0277, -0.0168]])),\n", + " ('model.layers.26.mlp.gate_proj.weight',\n", + " tensor([[-1.9812e-02, 8.3232e-03, 3.0347e-03, ..., 2.1982e-02,\n", + " 1.3550e-02, -1.1203e-02],\n", + " [ 2.2460e-02, 4.9811e-03, -2.2167e-02, ..., 1.3932e-03,\n", + " 5.3891e-03, -2.8310e-02],\n", + " [ 1.1011e-02, -1.2903e-02, -2.8861e-02, ..., 2.6808e-02,\n", + " -2.8479e-03, -1.3105e-02],\n", + " ...,\n", + " [ 1.1078e-03, -1.1789e-02, -4.4165e-02, ..., 8.2950e-03,\n", + " -1.8015e-02, -1.2234e-02],\n", + " [-2.0721e-02, -4.7919e-04, -4.9474e-02, ..., 7.9999e-05,\n", + " 1.7886e-02, -4.4699e-02],\n", + " [ 8.1279e-03, 1.2636e-02, -2.0932e-02, ..., -3.0361e-03,\n", + " 3.3468e-03, 2.7677e-02]])),\n", + " ('model.layers.26.mlp.up_proj.weight',\n", + " tensor([[-0.0301, -0.0025, -0.0147, ..., -0.0186, 0.0058, -0.0057],\n", + " [ 0.0303, -0.0341, 0.0142, ..., -0.0252, -0.0247, 0.0280],\n", + " [ 0.0209, -0.0425, 0.0073, ..., 0.0063, -0.0040, -0.0076],\n", + " ...,\n", + " [-0.0172, -0.0199, 0.0125, ..., 0.0363, 0.0118, -0.0124],\n", + " [-0.0108, 0.0042, -0.0475, ..., 0.0091, -0.0185, 0.0144],\n", + " [-0.0275, -0.0049, 0.0183, ..., -0.0001, -0.0119, -0.0359]])),\n", + " ('model.layers.26.mlp.down_proj.weight',\n", + " tensor([[-0.0197, -0.0082, -0.0224, ..., -0.0469, -0.0076, -0.0375],\n", + " [-0.0070, -0.0071, 0.0190, ..., -0.0125, 0.0068, 0.0166],\n", + " [ 0.0062, -0.0072, 0.0189, ..., -0.0244, -0.0292, -0.0328],\n", + " ...,\n", + " [-0.0054, 0.0219, 0.0058, ..., 0.0118, 0.0136, -0.0221],\n", + " [-0.0133, 0.0299, -0.0182, ..., -0.0496, -0.0202, 0.0196],\n", + " [-0.0131, -0.0237, -0.0473, ..., 0.0066, 0.0119, 0.0100]])),\n", + " ('model.layers.26.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.26.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.27.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.27.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.27.mixer.in_proj.weight',\n", + " tensor([[ 0.0200, -0.0276, -0.0274, ..., 0.0282, 0.0025, 0.0215],\n", + " [ 0.0054, 0.0218, -0.0175, ..., -0.0054, 0.0211, -0.0073],\n", + " [ 0.0100, -0.0023, 0.0162, ..., 0.0008, -0.0193, -0.0050],\n", + " ...,\n", + " [-0.0241, -0.0197, -0.0142, ..., 0.0039, -0.0175, 0.0045],\n", + " [ 0.0214, 0.0137, -0.0155, ..., -0.0212, 0.0089, 0.0165],\n", + " [ 0.0086, 0.0181, 0.0069, ..., -0.0093, -0.0272, 0.0068]])),\n", + " ('model.layers.27.mixer.conv1d.weight',\n", + " tensor([[[ 0.0519, 0.2061, 0.2635, 0.4916]],\n", + " \n", + " [[ 0.3745, -0.0860, -0.2310, -0.4250]],\n", + " \n", + " [[ 0.0565, 0.3699, 0.2812, -0.4201]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.4073, 0.1852, -0.1687, -0.2643]],\n", + " \n", + " [[-0.0865, -0.0894, 0.2650, -0.4522]],\n", + " \n", + " [[-0.0987, 0.0925, -0.2098, 0.0325]]])),\n", + " ('model.layers.27.mixer.conv1d.bias',\n", + " tensor([-0.4788, -0.0231, -0.4210, ..., -0.3143, -0.2893, 0.0570])),\n", + " ('model.layers.27.mixer.out_proj.weight',\n", + " tensor([[-0.0294, -0.0038, -0.0213, ..., -0.0141, 0.0072, -0.0359],\n", + " [ 0.0131, 0.0173, 0.0159, ..., 0.0030, 0.0400, -0.0065],\n", + " [-0.0111, 0.0374, 0.0109, ..., -0.0338, 0.0312, 0.0073],\n", + " ...,\n", + " [-0.0004, 0.0282, 0.0148, ..., 0.0165, 0.0062, -0.0177],\n", + " [ 0.0265, -0.0331, -0.0056, ..., 0.0407, 0.0154, 0.0176],\n", + " [ 0.0209, -0.0293, 0.0009, ..., -0.0240, -0.0029, -0.0407]])),\n", + " ('model.layers.27.mlp.gate_proj.weight',\n", + " tensor([[-0.0118, 0.0202, -0.0012, ..., 0.0101, 0.0075, 0.0102],\n", + " [ 0.0102, -0.0062, 0.0330, ..., -0.0024, -0.0245, -0.0237],\n", + " [-0.0008, 0.0202, -0.0097, ..., 0.0022, -0.0152, -0.0128],\n", + " ...,\n", + " [-0.0461, 0.0178, 0.0253, ..., 0.0319, 0.0173, -0.0099],\n", + " [ 0.0014, -0.0256, 0.0224, ..., 0.0272, 0.0045, 0.0192],\n", + " [ 0.0146, -0.0357, -0.0089, ..., -0.0147, 0.0383, 0.0354]])),\n", + " ('model.layers.27.mlp.up_proj.weight',\n", + " tensor([[-3.1854e-02, -1.0290e-03, -3.4564e-03, ..., 3.3551e-03,\n", + " 3.2845e-02, 2.1107e-02],\n", + " [-4.8083e-04, -5.8388e-03, 1.7324e-03, ..., 2.0575e-02,\n", + " -1.1685e-02, 1.2504e-02],\n", + " [ 4.6267e-02, -1.8935e-02, -2.4184e-02, ..., -4.8211e-02,\n", + " -3.3912e-04, 3.0527e-02],\n", + " ...,\n", + " [-6.9427e-03, -4.8680e-03, 3.2021e-02, ..., 1.4236e-02,\n", + " 1.9532e-02, 1.3339e-02],\n", + " [ 1.2463e-02, -5.5923e-03, -1.5680e-02, ..., 8.7956e-03,\n", + " 2.8262e-02, -1.2526e-02],\n", + " [-4.8530e-03, -8.8749e-05, 3.3507e-02, ..., -2.8260e-02,\n", + " -2.0571e-03, -8.3943e-03]])),\n", + " ('model.layers.27.mlp.down_proj.weight',\n", + " tensor([[-0.0457, -0.0267, -0.0210, ..., -0.0093, -0.0016, -0.0008],\n", + " [-0.0053, 0.0284, -0.0003, ..., 0.0065, -0.0117, 0.0243],\n", + " [ 0.0120, 0.0023, -0.0180, ..., -0.0003, -0.0313, 0.0163],\n", + " ...,\n", + " [-0.0160, 0.0207, 0.0082, ..., 0.0153, 0.0131, 0.0034],\n", + " [-0.0073, 0.0424, 0.0274, ..., -0.0075, -0.0554, -0.0114],\n", + " [-0.0192, 0.0268, 0.0036, ..., 0.0094, 0.0045, 0.0030]])),\n", + " ('model.layers.27.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.27.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.norm.weight', tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('lm_head.weight',\n", + " tensor([[-0.0141, -0.0445, 0.0071, ..., -0.0143, -0.0239, -0.0512],\n", + " [ 0.0295, -0.0317, -0.0201, ..., -0.0082, 0.0231, -0.0030],\n", + " [-0.0255, -0.0139, 0.0020, ..., -0.0040, -0.0154, 0.0336],\n", + " ...,\n", + " [ 0.0095, 0.0361, 0.0135, ..., -0.0018, 0.0074, -0.0311],\n", + " [-0.0092, 0.0060, 0.0594, ..., -0.0046, 0.0117, 0.0364],\n", + " [ 0.0228, -0.0265, -0.0262, ..., 0.0038, 0.0097, -0.0257]]))])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.state_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "N params SSM: 5.305533088\n" + ] + } + ], + "source": [ + "print(\"N params SSM:\", sum(p.numel() for p in apriel_ssm.parameters() if p.requires_grad)/1e9)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load State dict into SSM" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "apriel_ssm.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.weight', 'model.layers.7.mixer.conv1d.bias', 'model.layers.7.mixer.out_proj.weight', 'model.layers.8.mixer.z_bias', 'model.layers.8.mixer.D', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.out_proj.weight', 'model.layers.9.mixer.z_bias', 'model.layers.9.mixer.D', 'model.layers.9.mixer.in_proj.weight', 'model.layers.9.mixer.conv1d.weight', 'model.layers.9.mixer.conv1d.bias', 'model.layers.9.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.19.mixer.z_bias', 'model.layers.19.mixer.D', 'model.layers.19.mixer.in_proj.weight', 'model.layers.19.mixer.conv1d.weight', 'model.layers.19.mixer.conv1d.bias', 'model.layers.19.mixer.out_proj.weight', 'model.layers.20.mixer.z_bias', 'model.layers.20.mixer.D', 'model.layers.20.mixer.in_proj.weight', 'model.layers.20.mixer.conv1d.weight', 'model.layers.20.mixer.conv1d.bias', 'model.layers.20.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 'model.layers.21.mixer.out_proj.weight', 'model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight', 'model.layers.23.mixer.z_bias', 'model.layers.23.mixer.D', 'model.layers.23.mixer.in_proj.weight', 'model.layers.23.mixer.conv1d.weight', 'model.layers.23.mixer.conv1d.bias', 'model.layers.23.mixer.out_proj.weight', 'model.layers.24.mixer.z_bias', 'model.layers.24.mixer.D', 'model.layers.24.mixer.in_proj.weight', 'model.layers.24.mixer.conv1d.weight', 'model.layers.24.mixer.conv1d.bias', 'model.layers.24.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight', 'model.layers.26.mixer.z_bias', 'model.layers.26.mixer.D', 'model.layers.26.mixer.in_proj.weight', 'model.layers.26.mixer.conv1d.weight', 'model.layers.26.mixer.conv1d.bias', 'model.layers.26.mixer.out_proj.weight', 'model.layers.27.mixer.z_bias', 'model.layers.27.mixer.D', 'model.layers.27.mixer.in_proj.weight', 'model.layers.27.mixer.conv1d.weight', 'model.layers.27.mixer.conv1d.bias', 'model.layers.27.mixer.out_proj.weight'], unexpected_keys=['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight'])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.load_state_dict(apriel_state_dict, strict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "apriel_ssm.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# apriel_ssm.state_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Save checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'apriel_ssm' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mapriel_ssm\u001b[49m\u001b[38;5;241m.\u001b[39msave_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints/ssm/apriel_ssm_instruct_base\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 2\u001b[0m save_config\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'apriel_ssm' is not defined" + ] + } + ], + "source": [ + "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_base\",\n", + " save_config=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "24" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.model.layers[0].mixer.n_v_heads" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Try a forward pass" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", + "batch_size = 1\n", + "max_length = 128\n", + "state = SimpleNamespace()\n", + "state.key_value_memory_dict = apriel_ssm.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)\n", + "state.batch_size = batch_size\n", + "state.seqlen_offset = 0\n", + "static_inputs = {\"inference_params\": state,\n", + " \"input_ids\": input_ids,\n", + " \"use_cache\": True,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-3.0781, 2.3594, 1.4609, ..., -2.3438, -1.9688, 0.6484],\n", + " [-5.8125, 4.9688, 0.4414, ..., -4.2500, -3.5156, -4.8125],\n", + " [-5.5000, 3.3594, 1.1484, ..., -3.4375, -2.3125, -4.4375],\n", + " ...,\n", + " [-2.2812, 0.1465, 2.2344, ..., -7.6875, -3.0312, -6.2500],\n", + " [-6.8750, 1.7812, -1.3750, ..., -7.4688, -5.6875, -4.4062],\n", + " [-2.0156, 2.0938, 3.1094, ..., -3.0156, -2.1406, -2.2812]]],\n", + " device='cuda:0', grad_fn=), all_hidden_states=(), last_hidden_state=tensor([[[-1.3828, 0.0625, -2.7500, ..., -0.6523, -0.8906, 1.4609],\n", + " [ 2.1406, -0.0247, -3.0156, ..., -0.0074, 1.0234, 1.3828],\n", + " [ 1.6016, -0.7266, -1.2422, ..., -0.4004, -0.8242, -0.5586],\n", + " ...,\n", + " [ 1.5234, -0.0262, -1.5469, ..., -0.4922, -1.0078, 1.2344],\n", + " [-0.4629, -0.6055, -1.3906, ..., -0.9922, -0.3066, 1.1875],\n", + " [-0.7539, -0.0243, -2.4688, ..., -1.0625, -2.7188, 2.6875]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=))" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.forward(**static_inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load mdoel" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import torch\n", + "from mamba_ssm import MambaLMHeadModel\n", + "from mamba_ssm.models.config_mamba import MambaConfig\n", + "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", + "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", + "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", + "from transformers.cache_utils import StaticCache\n", + "from types import SimpleNamespace\n", + "import os\n", + "import shutil\n", + "# make sure the code changes reflected without reload\n", + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "model_path = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/apriel_ssminstr-distil-randinit-bs768-lr0.0003-sl4096_ti5000_luke_mix1/export/apriel_ssm/5000\"\n", + "modeling_path = \"/home/toolkit/dev/Fast-LLM/fast_llm/models/ssm/external\"\n", + "# # copy the config.json to the model path\n", + "shutil.copy(os.path.join(modeling_path, \"modeling_ssm_apriel.py\"), os.path.join(model_path, \"modeling_ssm_apriel.py\"))\n", + "shutil.copy(os.path.join(modeling_path, \"configuration_ssm_apriel.py\"), os.path.join(model_path, \"configuration_ssm_apriel.py\"))\n", + "\n", + "tokenizer_path = \"/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/\"\n", + "# # cp tokenizer*\n", + "# shutil.copy(os.path.join(tokenizer_path, \"tokenizer.json\"), os.path.join(model_path, \"tokenizer.json\"))\n", + "# shutil.copy(os.path.join(tokenizer_path, \"tokenizer_config.json\"), os.path.join(model_path, \"tokenizer_config.json\"))\n", + "# shutil.copy(os.path.join(tokenizer_path, \"special_tokens_map.json\"), os.path.join(model_path, \"special_tokens_map.json\"))\n", + "# shutil.copy(os.path.join(tokenizer_path, \"vocab.json\"), os.path.join(model_path, \"vocab.json\"))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", + "Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.08s/it]\n" + ] + } + ], + "source": [ + "\n", + "apriel_ssm = AprielSSMForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device=\"cuda\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "config = apriel_ssm.config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Mamba in Llama" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "\n", + "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", + "import torch\n", + "from mamba_ssm import MambaLMHeadModel\n", + "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", + "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", + "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", + "from transformers.cache_utils import StaticCache\n", + "from types import SimpleNamespace\n", + "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig\n", + "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridModel\n", + "# from fast_llm.models.ssm.external.__hybrid_wrapper import MambaTransformerHybridModelWrapper\n", + "# make sure the code changes reflected without reload\n", + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMHybridConfig {\n", + " \"_name_or_path\": \"ServiceNow-AI/Apriel-5B-Instruct\",\n", + " \"architectures\": [\n", + " \"AprielForCausalLM\"\n", + " ],\n", + " \"attention_bias\": false,\n", + " \"attention_dropout\": 0.0,\n", + " \"auto_map\": {\n", + " \"AutoConfig\": \"ServiceNow-AI/Apriel-5B-Instruct--configuration_apriel.AprielConfig\",\n", + " \"AutoModelForCausalLM\": \"ServiceNow-AI/Apriel-5B-Instruct--modeling_apriel.AprielForCausalLM\"\n", + " },\n", + " \"bos_token_id\": 1,\n", + " \"eos_token_id\": 2,\n", + " \"head_dim\": 128,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_size\": 4096,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 8192,\n", + " \"max_position_embeddings\": 16384,\n", + " \"mlp_bias\": false,\n", + " \"model_type\": \"apriel\",\n", + " \"num_attention_heads\": 24,\n", + " \"num_hidden_layers\": 28,\n", + " \"num_key_value_heads\": 8,\n", + " \"pretraining_tp\": 1,\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"rope_scaling\": {\n", + " \"attention_factor\": null,\n", + " \"beta_fast\": 32.0,\n", + " \"beta_slow\": 1.0,\n", + " \"factor\": 32.0,\n", + " \"original_max_position_embeddings\": 4096,\n", + " \"rope_type\": \"yarn\"\n", + " },\n", + " \"rope_theta\": 1000000.0,\n", + " \"ssm_block_pattern\": [\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\"\n", + " ],\n", + " \"ssm_cfg\": {\n", + " \"activation\": \"identity\",\n", + " \"bias\": false,\n", + " \"chunk_size\": 128,\n", + " \"d_inner\": 3072,\n", + " \"d_state\": 64,\n", + " \"expand\": 1,\n", + " \"n_qk_heads\": 24,\n", + " \"n_v_heads\": 24\n", + " },\n", + " \"tie_word_embeddings\": false,\n", + " \"torch_dtype\": \"bfloat16\",\n", + " \"transformers_version\": \"4.48.1\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 131072\n", + "}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),\n", + " ssm_block_pattern=[\"m2d\", \"t\"] * 14,\n", + " ssm_cfg=None)\n", + "hybrdif_apriel_config" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "28" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config.num_hidden_layers" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "hybrid_apriel_model = AprielSSMHybridModel(hybrdif_apriel_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", + "batch_size = 1\n", + "max_length = 128\n", + "state = SimpleNamespace()\n", + "state.key_value_memory_dict = hybrid_apriel_model.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)\n", + "state.batch_size = batch_size\n", + "state.seqlen_offset = 0\n", + "static_inputs = {\"inference_params\": state,\n", + " \"input_ids\": input_ids,\n", + " \"use_cache\": True,\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMHybridModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (1): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (2): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (3): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (4): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (5): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (6): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (7): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (8): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (9): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (10): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (11): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (12): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (13): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (14): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (15): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (16): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (17): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (18): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (19): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (20): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (21): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (22): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (23): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (24): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (25): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (26): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (27): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + ")" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrid_apriel_model.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BaseModelOutputWithPast(last_hidden_state=tensor([[[ 2.2031, 0.1777, 0.4258, ..., -2.0312, 0.2246, 0.5664],\n", + " [ 0.0562, -1.1016, 0.4590, ..., -2.1719, -0.1455, -0.6992],\n", + " [-1.5078, -1.3516, 0.8789, ..., -1.9141, 1.3672, -1.0391],\n", + " ...,\n", + " [-1.4453, 0.1260, 0.6992, ..., 0.4746, -0.1729, -0.5938],\n", + " [-0.4961, -0.4160, -0.4551, ..., -0.1328, 0.7461, -0.0376],\n", + " [ 0.3184, 0.4355, -0.7578, ..., 1.5547, 0.8555, -0.8711]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), past_key_values=DynamicCache(), hidden_states=None, attentions=None)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "hybrid_apriel_model.forward(**static_inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 9.73it/s]\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m apriel_model \u001b[38;5;241m=\u001b[39m AutoModelForCausalLM\u001b[38;5;241m.\u001b[39mfrom_pretrained(checkpoint, torch_dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mbfloat16, trust_remote_code\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 5\u001b[0m apriel_state_dict \u001b[38;5;241m=\u001b[39m apriel_model\u001b[38;5;241m.\u001b[39mstate_dict()\n\u001b[0;32m----> 6\u001b[0m \u001b[43mapriel_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mbfloat16)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3110\u001b[0m, in \u001b[0;36mPreTrainedModel.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3105\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype_present_in_args:\n\u001b[1;32m 3106\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 3107\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3108\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m `dtype` by passing the correct `torch_dtype` argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3109\u001b[0m )\n\u001b[0;32m-> 3110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1174\u001b[0m, in \u001b[0;36mModule.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1171\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1172\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[0;32m-> 1174\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconvert\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:780\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 780\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 784\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:780\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 780\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 784\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:805\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 801\u001b[0m \u001b[38;5;66;03m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[1;32m 802\u001b[0m \u001b[38;5;66;03m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[1;32m 803\u001b[0m \u001b[38;5;66;03m# `with torch.no_grad():`\u001b[39;00m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 805\u001b[0m param_applied \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 806\u001b[0m p_should_use_set_data \u001b[38;5;241m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[1;32m 808\u001b[0m \u001b[38;5;66;03m# subclasses may have multiple child tensors so we need to use swap_tensors\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1160\u001b[0m, in \u001b[0;36mModule.to..convert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 1153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m convert_to_format \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m t\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m5\u001b[39m):\n\u001b[1;32m 1154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(\n\u001b[1;32m 1155\u001b[0m device,\n\u001b[1;32m 1156\u001b[0m dtype \u001b[38;5;28;01mif\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1157\u001b[0m non_blocking,\n\u001b[1;32m 1158\u001b[0m memory_format\u001b[38;5;241m=\u001b[39mconvert_to_format,\n\u001b[1;32m 1159\u001b[0m )\n\u001b[0;32m-> 1160\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1161\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1162\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_floating_point\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_complex\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1163\u001b[0m \u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1164\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1165\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1166\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(e) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot copy out of meta tensor; no data!\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n" + ] + } + ], + "source": [ + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()\n", + "apriel_model.to(device).to(dtype=torch.bfloat16)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielConfig {\n", + " \"_name_or_path\": \"ServiceNow-AI/Apriel-5B-Instruct\",\n", + " \"architectures\": [\n", + " \"AprielForCausalLM\"\n", + " ],\n", + " \"attention_bias\": false,\n", + " \"attention_dropout\": 0.0,\n", + " \"auto_map\": {\n", + " \"AutoConfig\": \"ServiceNow-AI/Apriel-5B-Instruct--configuration_apriel.AprielConfig\",\n", + " \"AutoModelForCausalLM\": \"ServiceNow-AI/Apriel-5B-Instruct--modeling_apriel.AprielForCausalLM\"\n", + " },\n", + " \"bos_token_id\": 1,\n", + " \"eos_token_id\": 2,\n", + " \"head_dim\": 128,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_size\": 4096,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 8192,\n", + " \"max_position_embeddings\": 16384,\n", + " \"mlp_bias\": false,\n", + " \"model_type\": \"apriel\",\n", + " \"num_attention_heads\": 24,\n", + " \"num_hidden_layers\": 28,\n", + " \"num_key_value_heads\": 8,\n", + " \"pretraining_tp\": 1,\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"rope_scaling\": {\n", + " \"attention_factor\": null,\n", + " \"beta_fast\": 32.0,\n", + " \"beta_slow\": 1.0,\n", + " \"factor\": 32.0,\n", + " \"original_max_position_embeddings\": 4096,\n", + " \"rope_type\": \"yarn\"\n", + " },\n", + " \"rope_theta\": 1000000.0,\n", + " \"tie_word_embeddings\": false,\n", + " \"torch_dtype\": \"bfloat16\",\n", + " \"transformers_version\": \"4.48.1\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 131072\n", + "}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "d_xb = config.num_key_value_heads * config.head_dim\n", + "ssm_layers = [2,4,8]\n", + "attn_layers = [i for i in range(config.num_hidden_layers) if i not in ssm_layers]\n", + "model_name = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "ngroups = config.num_attention_heads # n heads\n", + "d_inner = config.head_dim * config.num_attention_heads\n", + "headdim = 128 # d_state\n", + "d_state = config.head_dim\n", + "d_model = config.hidden_size \n", + "assert d_inner == ngroups * d_state\n", + "\n", + "mamba_config = AprielSSMConfig(\n", + " ssm_cfg={\n", + " \"d_state\": 64,\n", + " \"n_v_heads\": 24,\n", + " \"n_qk_heads\": 24,\n", + " \"expand\": 1,\n", + " \"chunk_size\": 128,\n", + " \"activation\": \"identity\",\n", + " \"bias\": False,\n", + " \"d_inner\": 24 * headdim, # num_heads * head_dim\n", + " },\n", + " vocab_size=config.vocab_size, \n", + " hidden_size=config.hidden_size,\n", + " intermediate_size=config.intermediate_size,\n", + " num_hidden_layers=config.num_hidden_layers,\n", + " hidden_act=config.hidden_act,\n", + " initializer_range=config.initializer_range,\n", + " use_cache=config.use_cache,\n", + " mlp_bias=config.mlp_bias,\n", + " tie_word_embeddings=config.tie_word_embeddings,\n", + " pad_token_id=config.pad_token_id,\n", + " bos_token_id=config.bos_token_id,\n", + " eos_token_id=config.eos_token_id,\n", + " head_dim=config.head_dim,\n", + " rms_norm_eps=config.rms_norm_eps\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "student_model = MambaTransformerHybridModelWrapper.init_distillation(None, model_name, \n", + " mamba_config, \n", + " attn_layers=attn_layers, \n", + " init_with_kqvo=True, \n", + " attn_implementation=\"flash_attention_2\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hymba2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py new file mode 100644 index 000000000..588918025 --- /dev/null +++ b/fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py @@ -0,0 +1,446 @@ +import math +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torch_available, logging + +logger = logging.get_logger(__name__) + +if is_torch_available(): + import torch + + +def _compute_default_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + + +def _compute_yarn_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://arxiv.org/abs/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # No need to keep BC with yarn, unreleased when this new pattern was created. + if len(rope_kwargs) > 0: + raise ValueError( + f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + # Apriel: Use original max_position_embeddings instead of max_position_embeddings + max_position_embeddings = config.rope_scaling.get( + "original_max_position_embeddings", config.max_position_embeddings + ) + factor = config.rope_scaling["factor"] + + # Sets the attention factor as suggested in the paper + attention_factor = config.rope_scaling.get("attention_factor") + if attention_factor is None: + attention_factor = 0.1 * math.log(factor) + 1.0 + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = config.rope_scaling.get("beta_fast") or 32 + beta_slow = config.rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): + """Find dimension range bounds based on rotations""" + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. + pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) + + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, attention_factor + + +def _check_received_keys( + rope_type: str, + received_keys: set, + required_keys: set, + optional_keys: Optional[set] = None, + ignore_keys: Optional[set] = None, +): + """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" + # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present + if "type" in received_keys: + received_keys -= {"type"} + required_keys.add("rope_type") + + # Some models need to store model-specific keys, and we don't want to throw warning at them + if ignore_keys is not None: + received_keys -= ignore_keys + + missing_keys = required_keys - received_keys + if missing_keys: + raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") + + if optional_keys is not None: + unused_keys = received_keys - required_keys - optional_keys + else: + unused_keys = received_keys - required_keys + if unused_keys: + logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") + + +def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + +def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "original_max_position_embeddings"} + optional_keys = {"attention_factor", "beta_fast", "beta_slow"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_scaling.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_scaling.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + logger.warning( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) + + +# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters +# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE +# parameterizations, as long as the callable has the same signature. +ROPE_INIT_FUNCTIONS = { + "default": _compute_default_rope_parameters, + "yarn": _compute_yarn_parameters, +} + +# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. +ROPE_VALIDATION_FUNCTIONS = { + "default": _validate_default_rope_parameters, + "yarn": _validate_yarn_parameters, +} + + +def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): + """ + Validate the RoPE config arguments, given a `PretrainedConfig` object + """ + rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` + if rope_scaling is None: + return + + # BC: "rope_type" was originally "type" + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) + if validation_fn is not None: + validation_fn(config, ignore_keys=ignore_keys) + else: + logger.warning( + f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" + ) + + +class AprielSSMHybridConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AprielModel`]. It is used to instantiate an Apriel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Apriel-5B-Base. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Apriel model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`AprielModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Apriel-5B-Base supports up to 16384 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'yarn', 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_attention_heads + ```python + >>> from transformers import AprielModel, AprielConfig + >>> # Initializing an Apriel Apriel-5B-Base style configuration + >>> configuration = AprielConfig() + >>> # Initializing a model from the Apriel-5B-Base style configuration + >>> model = AprielModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "apriel" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `AprielModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + ssm_block_pattern=["m2d"], + ssm_cfg=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.ssm_block_pattern = ssm_block_pattern + if len(ssm_block_pattern) == 1: + self.ssm_block_pattern = [ssm_block_pattern[0]] * self.num_hidden_layers + assert len(self.ssm_block_pattern) == self.num_hidden_layers + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 24, + "n_qk_heads": 24, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_inner": 24 * self.head_dim, # num_heads * head_dim + } + + +__all__ = ["AprielConfig"] diff --git a/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py new file mode 100644 index 000000000..49b009866 --- /dev/null +++ b/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py @@ -0,0 +1,1203 @@ +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from mamba_ssm.utils.generation import GenerationMixin +from torch import nn +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.utils.generic import ModelOutput + +from fast_llm.models.ssm.external.configuration_ssm_hybrid_apriel import ROPE_INIT_FUNCTIONS, AprielSSMHybridConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class CustomMambaCausalLMOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + + +class AprielRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): + """ + AprielRMSNorm is equivalent to T5LayerNorm + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(AprielRMSNorm) + + +class AprielMLP(nn.Module): + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class AprielRotaryEmbedding(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class AprielAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs): + """ + u: (B, L, D) + Returns: same shape as u + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + state = None + if inference_params is not None: + state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # States are updated inplace + out, _ = self.step(u, state) + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(state is not None), + ) + + if state is not None: + y, ssm_state = result + state["ssm"].copy_(ssm_state) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + def step(self, u, state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u.squeeze(1)) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state = self.convolutional_step(xBC, state["conv"]) + state["conv"].copy_(conv_state) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + state["ssm"] = state["ssm"].to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=state["ssm"], # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.in_proj.weight.device + # conv_state: + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, + self.d_conv, + self.conv1d.weight.shape[0], + device=device, + dtype=conv_dtype, + ).transpose(1, 2) + # ssm_state: + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + ssm_state = torch.zeros( + batch_size, + self.n_v_heads, + self.headdim, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + if self.layer_idx not in inference_params.key_value_memory_dict: + inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + batch_size, inference_params.max_seqlen, dtype=torch.float32 + ) + # Get states + states = inference_params.key_value_memory_dict[self.layer_idx] + if initialize_states: + states["conv"].zero_() + states["ssm"].zero_() + return states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = AprielAttention(config=config, layer_idx=layer_idx) + + self.mlp = AprielMLP(config) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + inference_params=None, # just to be compatible with SSM block + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class AprielSSMDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + self.mlp = AprielMLP(config, **factory_kwargs) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + + def forward( + self, hidden_states: torch.Tensor, inference_params=None, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + inference_params=inference_params, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + # outputs["hidden_states"] = hidden_states + outputs = (hidden_states,) + + return outputs + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for the model.""" + if getattr(self.mixer, "allocate_inference_cache", None) is None: + return + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +APRIEL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`AprielSSMHybridConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMPreTrainedModel(PreTrainedModel): + config_class = AprielSSMHybridConfig + base_model_prefix = "model" + _no_split_modules = ["AprielDecoderLayer", "AprielSSMDecoderLayer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) + + +APRIEL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMHybridModel(AprielSSMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] + Args: + config: AprielSSMHybridConfig + """ + + def __init__(self, config: AprielSSMHybridConfig, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + factory_kwargs = {"device": device, "dtype": dtype} + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) + blocks = [] + for layer_idx, type in enumerate(config.ssm_block_pattern): + if type == "m2d": + blocks.append(AprielSSMDecoderLayer(config, layer_idx, **factory_kwargs)) + elif type == "t": + blocks.append(AprielDecoderLayer(config, layer_idx)) + else: + raise ValueError(f"Invalid block type: {type}") + self.layers = nn.ModuleList(blocks) + self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.gradient_checkpointing = False + self.rotary_emb = AprielRotaryEmbedding(config=config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + cache = {} + for i, layer in enumerate(self.layers): + if isinstance(layer, AprielSSMDecoderLayer): + cache[i] = layer.allocate_inference_cache(*args, **kwargs) + return cache + + @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + inference_params=None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + inference_params=inference_params, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions and isinstance(decoder_layer, AprielDecoderLayer): + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielSSMHybridForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) + self.model = AprielSSMHybridModel(config) + self.vocab_size = config.vocab_size + factory_kwargs = {"device": device, "dtype": dtype} + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids=None, + return_hidden_states=False, + return_logits=True, + inference_params=None, + num_last_tokens=0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + + outputs = self.model( + input_ids, + return_hidden_states=return_hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + + if outputs["last_hidden_state"] is not None and return_logits: + logits = self.lm_head(outputs["last_hidden_state"]).float() + outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] + else: + outputs["logits"] = None + + return CustomMambaCausalLMOutput( + loss=None, + logits=outputs["logits"], + all_hidden_states=outputs["all_hidden_states"], + last_hidden_state=outputs["last_hidden_state"], + ) + + def generate(self, *args, **kwargs): + """ + This is a wrapper to make sure we comply with the HF generation interface for eval harness + """ + return super().generate(*args, **kwargs) + + +__all__ = [ + "AprielSSMForCausalLM", + "AprielModel", + "AprielSSMPreTrainedModel", +] From 9a678df83ff6fe1a5a1f5b447b616836c0e6b5c3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 12:59:54 +0000 Subject: [PATCH 052/122] sft distill --- fast_llm/models/gpt/config.py | 6 +++--- fast_llm/models/ssm/config.py | 10 ++++------ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index b82dd3e85..1ddb7ed20 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -193,9 +193,9 @@ def _validate(self) -> None: Assert.eq(self.reference_models.keys(), {name}) if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) - if self.model.base_model.distillation_model is not None: - # TODO: Support loss masking for distillation? - assert not self.batch.use_loss_masking_spans + # if self.model.base_model.distillation_model is not None: + # # TODO: Support loss masking for distillation? + # assert not self.batch.use_loss_masking_spans for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) # TODO: Support more LM head features. diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index d77d206b0..1d8ac007c 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -172,9 +172,7 @@ class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) - reference_models: dict[str, PretrainedGPTModelConfig] = ( - FieldUpdate() - ) # TODO: make sure any reference mdoel can be suported + reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["SSMTrainer"]: @@ -190,9 +188,9 @@ def _validate(self) -> None: Assert.eq(self.reference_models.keys(), {name}) if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) - if self.model.base_model.distillation_model is not None: - # TODO: Support loss masking for distillation? - assert not self.batch.use_loss_masking_spans + # if self.model.base_model.distillation_model is not None: + # # TODO: Support loss masking for distillation? + # assert not self.batch.use_loss_masking_spans for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) # TODO: Support more LM head features. From a7abe53383286f92fd26ad5ed93f11010e52e4c9 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 14:21:32 +0000 Subject: [PATCH 053/122] conversion --- fast_llm/models/ssm/conversion.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index c2e54ca09..3d3aa7284 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -52,11 +52,11 @@ def _import_config(cls, config, architecture_only: bool = False): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: if cls.block_pattern is not None: - block_converter = MappedConfigParamConverter( - fast_llm_names=(("hybrid_block_layout",),), - export_names=(("hybrid_block_layout",),), - fast_llm_value=cls.block_pattern, - export_value=cls.block_pattern, + block_converter = ( + RenameParamConverter( + fast_llm_names=(("hybrid_block_layout",),), + export_names=(("hybrid_block_layout",),), + ), ) else: block_converter = ConstantImportParamConverter( From a68c0b7318ec07abebea8acc0c035f9f8877017e Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 14:30:10 +0000 Subject: [PATCH 054/122] conversion --- fast_llm/models/ssm/conversion.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 3d3aa7284..675c709f8 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -52,11 +52,9 @@ def _import_config(cls, config, architecture_only: bool = False): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: if cls.block_pattern is not None: - block_converter = ( - RenameParamConverter( - fast_llm_names=(("hybrid_block_layout",),), - export_names=(("hybrid_block_layout",),), - ), + block_converter = RenameParamConverter( + fast_llm_names=(("hybrid_block_layout",),), + export_names=(("hybrid_block_layout",),), ) else: block_converter = ConstantImportParamConverter( From 9cfef449bb232e6c3895f7406f04e67b7a62fea9 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 15:30:33 +0000 Subject: [PATCH 055/122] lr stage definition as string --- fast_llm/engine/optimizer/learning_rate.py | 26 +++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/fast_llm/engine/optimizer/learning_rate.py b/fast_llm/engine/optimizer/learning_rate.py index bf11038a5..c6912e4f1 100644 --- a/fast_llm/engine/optimizer/learning_rate.py +++ b/fast_llm/engine/optimizer/learning_rate.py @@ -120,19 +120,19 @@ def create_schedule_from_config(config: LearningRateScheduleConfig) -> LearningR begin_step = 0 for stage_arg_str in config.schedule.split(";"): try: - for stage_type, num_steps, lr, *stage_args in stage_arg_str.split(","): - assert begin_step is not None - num_steps = int(num_steps) - end_step = None if num_steps < 0 else begin_step + num_steps - kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)} - if len(stage_args) > 0: - kwargs["end_lr"] = float(stage_args[0]) - if len(stage_args) > 1: - kwargs["power"] = float(stage_args[1]) - if len(stage_args) > 2: - raise ValueError(stage_args[2:]) - stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs)) - begin_step = end_step + stage_type, num_steps, lr, *stage_args = stage_arg_str.split(",") + assert begin_step is not None + num_steps = int(num_steps) + end_step = None if num_steps < 0 else begin_step + num_steps + kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)} + if len(stage_args) > 0: + kwargs["end_lr"] = float(stage_args[0]) + if len(stage_args) > 1: + kwargs["power"] = float(stage_args[1]) + if len(stage_args) > 2: + raise ValueError(stage_args[2:]) + stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs)) + begin_step = end_step except Exception: raise ValueError(f'Cannot parse optimizer stage definition "{stage_arg_str}"') return LearningRateSchedule(stages) From 005e623e08936f7addad87a083fde3aab176ceef Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 2 May 2025 11:59:15 -0400 Subject: [PATCH 056/122] fixes --- fast_llm/functional/triton/mlp.py | 12 +++++++----- fast_llm/layers/ssm/llamba_block.py | 2 +- fast_llm/layers/transformer/transformer.py | 3 +-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index 5b220b1ac..ee3ba304c 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -50,19 +50,21 @@ def triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - if activation_type == _TritonActivationType.gelu.value: + if activation_type == _TritonActivationType.gelu: tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) - elif activation_type == _TritonActivationType.silu.value: + elif activation_type == _TritonActivationType.silu: out = input_ / (1 + tl.exp(-input_)) - elif activation_type == _TritonActivationType.relu.value: + elif activation_type == _TritonActivationType.relu: out = tl.where(input_ > 0, input_, 0) elif activation_type == _TritonActivationType.squared_relu: relu_out = tl.where(input_ > 0, input_, 0) out = relu_out * relu_out + elif activation_type == _TritonActivationType.identity: + out = input_ else: - raise NotImplementedError() + tl.static_assert(False, activation_type) if gated: other = tl.load(input_ptr + n_cols, mask=mask) @@ -124,7 +126,7 @@ def triton_mlp_activation_backward_kernel( if gated or recompute: out = input_ else: - raise NotImplementedError() + tl.static_assert(False, activation_type) if gated: other = tl.load(input_ptr + n_cols, mask=mask) diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index 221356389..ee222d6d2 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -13,7 +13,7 @@ class LlambaBlock(BaseBlock): A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ - name = "Llamba block" + _name = "Llamba block" _mixer_module_name = "mixer" def __init__( diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 92df18937..40dd2e00e 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -23,7 +23,6 @@ class BaseBlock(Layer, abc.ABC): A transformer-like decoder base block block with abstract mixer. """ - name = "Transformer layer" _mixer_module_name = "self_attn" def __init__( @@ -137,7 +136,7 @@ def forward( class TransformerLayer(BaseBlock): - name = "Transformer layer" + _name = "Transformer layer" _mixer_module_name = "self_attn" def __init__( From cad951aafd0f2ac8048929c8fea5b5b67c0f6a42 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 2 May 2025 12:31:07 -0400 Subject: [PATCH 057/122] fix --- fast_llm/layers/language_model/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4fb471fb3..b4b4e187c 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -215,8 +215,6 @@ def _validate(self) -> None: super()._validate() if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, self.init_method_max_embed) - if self.prediction_heads > 1: - Assert.gt(self.transformer.num_layers, 1) if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") From bce916d01566541ee364bcd2fd395db8f4010ff6 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 17:07:43 +0000 Subject: [PATCH 058/122] loss maks --- fast_llm/functional/cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 34c69d797..401cfe073 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -187,9 +187,9 @@ def cross_entropy_forward_backward( if group: Assert.eq(implementation, CrossEntropyImpl.fused) return _fused_cross_entropy_forward_backward( - logits, target, grad_output, logits_scale_factor, target_format, group + logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, grad_output, logits_scale_factor, target_format + logits, target, loss_mask, grad_output, logits_scale_factor, target_format ) From 9d9506418cc290c549af6aa6c9b8040f0ea7e1e5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 2 May 2025 13:14:02 -0400 Subject: [PATCH 059/122] fix --- fast_llm/functional/cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 34c69d797..401cfe073 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -187,9 +187,9 @@ def cross_entropy_forward_backward( if group: Assert.eq(implementation, CrossEntropyImpl.fused) return _fused_cross_entropy_forward_backward( - logits, target, grad_output, logits_scale_factor, target_format, group + logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, grad_output, logits_scale_factor, target_format + logits, target, loss_mask, grad_output, logits_scale_factor, target_format ) From 935c470b4ade91462a92cfe26aae9be1e32e2154 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 2 May 2025 13:30:28 -0400 Subject: [PATCH 060/122] fix --- fast_llm/data/dataset/gpt/sampled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f3633a76a..065eb94d8 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -187,7 +187,7 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) - self._load_yaml_data(yaml_data) + self._load_yaml_data(loaded_yaml_data) if not self._truncate_documents: del loaded_yaml_data["unshuffled_tokens"] From 9aff3b70cb13d57920b98d4d1b65926fdea5fa5f Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 17:39:10 +0000 Subject: [PATCH 061/122] fix shuffled tokens --- fast_llm/data/dataset/gpt/sampled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f3633a76a..065eb94d8 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -187,7 +187,7 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) - self._load_yaml_data(yaml_data) + self._load_yaml_data(loaded_yaml_data) if not self._truncate_documents: del loaded_yaml_data["unshuffled_tokens"] From ae4d111a26105362ec4f49aeea8da77e522f98c0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 2 May 2025 15:08:40 -0400 Subject: [PATCH 062/122] fixes --- tests/layers/test_lm_head.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 14edecffd..b32292bdd 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -126,15 +126,13 @@ def test_lm_head( else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) ) if loss_masking: - loss_mask = torch.randint( - 0, - VOCAB_SIZE, - label_shape, - dtype=torch.bool, - device=distributed.device, - ) + loss_mask = torch.randint(0, 2, label_shape, dtype=torch.bool, device=distributed.device) else: loss_mask = None + kwargs = { + TransformerKwargs.sequence_first: sequence_first, + TransformerKwargs.grad_output: 1.0, + } if config.distillation_model is None: target = torch.randint( 0, @@ -145,14 +143,17 @@ def test_lm_head( ) if loss_mask is not None: target *= loss_mask + + kwargs[LanguageModelKwargs.labels] = target else: assert config.prediction_heads == 1 - target = torch.randn_like(input_) - kwargs = { - TransformerKwargs.sequence_first: sequence_first, - LanguageModelKwargs.labels: target, - TransformerKwargs.grad_output: 1.0, - } + target = torch.randn( + input_.shape[:-1] + (VOCAB_SIZE,), + dtype=input_.dtype, + device=distributed.device, + ) + kwargs[f"{config.distillation_model}_logits"] = target + if config.tie_word_embeddings or config.prediction_heads > 1: logit_weight = ( torch.empty( From deb7ce66a6bb15e60a493845a71b4f6ee9366ea2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 2 May 2025 18:00:13 -0400 Subject: [PATCH 063/122] fixes --- fast_llm/functional/cross_entropy.py | 12 +++++++----- fast_llm/functional/triton/cross_entropy.py | 16 ++++++++++++---- fast_llm/layers/language_model/head.py | 2 ++ tests/layers/test_lm_head.py | 10 +++++++--- 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 401cfe073..513510ec7 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -38,7 +38,7 @@ def _torch_cross_entropy_forward_backward( torch.nn.functional.cross_entropy( logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none" ) - * loss_mask.unsqueeze(-1) + * loss_mask ).mean() if grad_output is None: grad = None @@ -48,7 +48,7 @@ def _torch_cross_entropy_forward_backward( return loss.detach_(), grad -# @torch.compile +@torch.compile def _fused_softmax_base( logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -74,7 +74,7 @@ def _fused_softmax( return exp_logits / sum_exp_logits -@torch.compile +# @torch.compile def _fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -113,6 +113,8 @@ def _fused_cross_entropy_forward_backward( else: # Target should be tensor-parallel already, no further manipulation needed. target_mask = None + if loss_mask is not None: + loss_mask = loss_mask.unsqueeze(-1) if grad_output is None: grad = None @@ -128,9 +130,9 @@ def _fused_cross_entropy_forward_backward( grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits) if logits_scale_factor != 1.0: grad *= logits_scale_factor - grad = grad.to(logits.dtype) if loss_mask is not None: - grad = torch.where(loss_mask, grad.to(logits.dtype), 0) + grad *= loss_mask + grad = grad.to(logits.dtype) # loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) if target_format == TargetFormat.labels: diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 02dc1ce78..8cb59c85c 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -64,7 +64,6 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( n_cols: tl_constexpr, logits_stride_0: tl_constexpr, target_stride_0: tl_constexpr, - loss_mask_stride_0: tl_constexpr, grad_logits_stride_0: tl_constexpr, logits_scale_factor: tl_constexpr, from_logits: tl_constexpr, @@ -75,6 +74,14 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( col_offsets = tl.arange(0, block_size) mask = col_offsets < n_cols + if loss_mask_ptr is not None: + loss_mask = tl.load(loss_mask_ptr + block_idx) + if loss_mask == 0: + tl.store(losses_ptr + block_idx, 0) + if grad_losses is not None: + tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=mask) + return + logits = tl.load(logits_ptr + block_idx * logits_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( tl.float32 ) @@ -89,8 +96,6 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( target = tl.load(target_ptr + block_idx * target_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( tl.float32 ) - if loss_mask_ptr is not None: - loss_mask = tl.load(target_ptr + block_idx * target_stride_0 + col_offsets, mask=mask, other=0) if from_logits: if logits_scale_factor != 1.0: target *= logits_scale_factor @@ -108,6 +113,8 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) if logits_scale_factor != 1.0: grad_logits *= logits_scale_factor + if loss_mask_ptr is not None: + grad_logits = grad_logits tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) @@ -151,6 +158,8 @@ def triton_cross_entropy_forward_backward( num_warps=num_warps, ) else: + if loss_mask is not None: + assert loss_mask.is_contiguous() triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( logits, target, @@ -161,7 +170,6 @@ def triton_cross_entropy_forward_backward( n_cols, logits.stride(0), target.stride(0), - None if loss_mask is None else loss_mask.stride(0), None if grad_output is None else grad_logits.stride(0), logits_scale_factor, block_size=block_size, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 9b1dd4d8a..813dcc076 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -163,6 +163,8 @@ def _forward_backward( # Target is reference model logits. target = target.flatten(0, -2) loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) + if loss_mask is not None: + loss_mask = loss_mask.flatten() if self._sequence_parallel_logits: target = split_op(target, self._tensor_space.distributed.tensor_group, 0) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index b32292bdd..7578a5f05 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -40,14 +40,16 @@ def _lm_head( rms_weight, 1e-5, ) - logits = torch.nn.functional.linear(hidden, logit_weight) + logits = torch.nn.functional.linear(hidden, logit_weight).float() if logit_scale_factor != 1.0: logits *= logit_scale_factor z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None if target.ndim == logits.ndim: - loss = torch.nn.functional.cross_entropy(logits, target, reduction="none") + loss = torch.nn.functional.cross_entropy( + logits.flatten(0, -2), target.float().softmax(-1).flatten(0, -2), reduction="none" + ) if loss_mask is not None: - loss = loss * loss_mask.unsqueeze(-1) + loss = loss * loss_mask.flatten() loss = loss.mean() else: loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) @@ -153,6 +155,8 @@ def test_lm_head( device=distributed.device, ) kwargs[f"{config.distillation_model}_logits"] = target + if loss_mask is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask if config.tie_word_embeddings or config.prediction_heads > 1: logit_weight = ( From eaba34f66730d06e880b209f97f0560eead0e510 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 22:01:04 +0000 Subject: [PATCH 064/122] innit like in mamba in llama --- .../models/ssm/external/ariel_to_ssm.ipynb | 963 ++++-------------- 1 file changed, 213 insertions(+), 750 deletions(-) diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb index 496338cb0..664d927fa 100644 --- a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb +++ b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb @@ -29,6 +29,13 @@ "%autoreload 2\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Apriel SSM for distillation" + ] + }, { "cell_type": "code", "execution_count": 3, @@ -115,35 +122,6 @@ "apriel_model.config.torch_dtype" ] }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "n_params = sum(p.numel() for p in apriel_model.parameters() if p.requires_grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4.83207168" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "n_params/1e9" - ] - }, { "cell_type": "code", "execution_count": 8, @@ -161,62 +139,6 @@ "config_apriel = AprielSSMConfig.from_pretrained(\"/mnt/checkpoints_fml/pretrained_models/ssm/apriel_ssm_instruct_base\", trust_remote_code=True)" ] }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", - "You are using a model of type llamba to instantiate a model of type apriel_ssm. This is not supported for all configurations of models and can yield errors.\n" - ] - }, - { - "ename": "KeyError", - "evalue": "'n_qk_heads'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m stage2_checkpoint \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m stage2_apriel_ssm \u001b[38;5;241m=\u001b[39m \u001b[43mAprielSSMForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstage2_checkpoint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbfloat16\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3571\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3569\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m 3570\u001b[0m config_path \u001b[38;5;241m=\u001b[39m config \u001b[38;5;28;01mif\u001b[39;00m config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> 3571\u001b[0m config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3572\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3573\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3574\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_unused_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 3575\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3576\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3577\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3578\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3579\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3580\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3581\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3582\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_auto\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_auto_class\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3583\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_pipeline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_pipeline\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3584\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3585\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3587\u001b[0m \u001b[38;5;66;03m# In case one passes a config to `from_pretrained` + \"attn_implementation\"\u001b[39;00m\n\u001b[1;32m 3588\u001b[0m \u001b[38;5;66;03m# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3592\u001b[0m \u001b[38;5;66;03m# we pop attn_implementation from the kwargs but this handles the case where users\u001b[39;00m\n\u001b[1;32m 3593\u001b[0m \u001b[38;5;66;03m# passes manually the config to `from_pretrained`.\u001b[39;00m\n\u001b[1;32m 3594\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config)\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:569\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type:\n\u001b[1;32m 564\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 565\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are using a model of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to instantiate a model of type \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 566\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 567\u001b[0m )\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:740\u001b[0m, in \u001b[0;36mPretrainedConfig.from_dict\u001b[0;34m(cls, config_dict, **kwargs)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;66;03m# We remove it from kwargs so that it does not appear in `return_unused_kwargs`.\u001b[39;00m\n\u001b[1;32m 738\u001b[0m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 740\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(config, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpruned_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 743\u001b[0m config\u001b[38;5;241m.\u001b[39mpruned_heads \u001b[38;5;241m=\u001b[39m {\u001b[38;5;28mint\u001b[39m(key): value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m config\u001b[38;5;241m.\u001b[39mpruned_heads\u001b[38;5;241m.\u001b[39mitems()}\n", - "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/configuration_ssm_apriel.py:99\u001b[0m, in \u001b[0;36mAprielSSMConfig.__init__\u001b[0;34m(self, vocab_size, hidden_size, intermediate_size, num_hidden_layers, hidden_act, initializer_range, use_cache, pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings, mlp_bias, rms_norm_eps, ssm_cfg, head_dim, **kwargs)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 82\u001b[0m pad_token_id\u001b[38;5;241m=\u001b[39mpad_token_id,\n\u001b[1;32m 83\u001b[0m bos_token_id\u001b[38;5;241m=\u001b[39mbos_token_id,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 87\u001b[0m )\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg \u001b[38;5;241m=\u001b[39m ssm_cfg \u001b[38;5;129;01mor\u001b[39;00m {\n\u001b[1;32m 90\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_state\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 91\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_v_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim, \u001b[38;5;66;03m# num_heads * head_dim\u001b[39;00m\n\u001b[1;32m 98\u001b[0m }\n\u001b[0;32m---> 99\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mssm_cfg\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mn_qk_heads\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n", - "\u001b[0;31mKeyError\u001b[0m: 'n_qk_heads'" - ] - } - ], - "source": [ - "stage2_checkpoint = \"/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\"\n", - "stage2_apriel_ssm = AprielSSMForCausalLM.from_pretrained(stage2_checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "apriel_ssm_config = AprielSSMConfig(vocab_size=config.vocab_size, \n", - " hidden_size=config.hidden_size,\n", - " intermediate_size=config.intermediate_size,\n", - " num_hidden_layers=config.num_hidden_layers,\n", - " hidden_act=config.hidden_act,\n", - " initializer_range=config.initializer_range,\n", - " use_cache=config.use_cache,\n", - " mlp_bias=config.mlp_bias,\n", - " tie_word_embeddings=config.tie_word_embeddings,\n", - " pad_token_id=config.pad_token_id,\n", - " bos_token_id=config.bos_token_id,\n", - " eos_token_id=config.eos_token_id,\n", - " head_dim=config.head_dim,\n", - " rms_norm_eps=config.rms_norm_eps)" - ] - }, { "cell_type": "code", "execution_count": 10, @@ -2330,15 +2252,6 @@ "apriel_ssm.to(device).to(dtype=torch.bfloat16)" ] }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# apriel_ssm.state_dict()" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -2503,20 +2416,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Load mdoel" + "## Load Apriel SSM into HF class" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 130, "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" ] } ], @@ -2632,12 +2545,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Mamba in Llama" + "# Mamba in Llama: SSM hybrid " ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 90, "metadata": {}, "outputs": [ { @@ -2660,7 +2573,7 @@ "from transformers.cache_utils import StaticCache\n", "from types import SimpleNamespace\n", "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig\n", - "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridModel\n", + "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridModel, AprielSSMDecoderLayer\n", "# from fast_llm.models.ssm.external.__hybrid_wrapper import MambaTransformerHybridModelWrapper\n", "# make sure the code changes reflected without reload\n", "%load_ext autoreload\n", @@ -2669,146 +2582,104 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 81, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMHybridConfig {\n", - " \"_name_or_path\": \"ServiceNow-AI/Apriel-5B-Instruct\",\n", - " \"architectures\": [\n", - " \"AprielForCausalLM\"\n", - " ],\n", - " \"attention_bias\": false,\n", - " \"attention_dropout\": 0.0,\n", - " \"auto_map\": {\n", - " \"AutoConfig\": \"ServiceNow-AI/Apriel-5B-Instruct--configuration_apriel.AprielConfig\",\n", - " \"AutoModelForCausalLM\": \"ServiceNow-AI/Apriel-5B-Instruct--modeling_apriel.AprielForCausalLM\"\n", - " },\n", - " \"bos_token_id\": 1,\n", - " \"eos_token_id\": 2,\n", - " \"head_dim\": 128,\n", - " \"hidden_act\": \"silu\",\n", - " \"hidden_size\": 4096,\n", - " \"initializer_range\": 0.02,\n", - " \"intermediate_size\": 8192,\n", - " \"max_position_embeddings\": 16384,\n", - " \"mlp_bias\": false,\n", - " \"model_type\": \"apriel\",\n", - " \"num_attention_heads\": 24,\n", - " \"num_hidden_layers\": 28,\n", - " \"num_key_value_heads\": 8,\n", - " \"pretraining_tp\": 1,\n", - " \"rms_norm_eps\": 1e-05,\n", - " \"rope_scaling\": {\n", - " \"attention_factor\": null,\n", - " \"beta_fast\": 32.0,\n", - " \"beta_slow\": 1.0,\n", - " \"factor\": 32.0,\n", - " \"original_max_position_embeddings\": 4096,\n", - " \"rope_type\": \"yarn\"\n", - " },\n", - " \"rope_theta\": 1000000.0,\n", - " \"ssm_block_pattern\": [\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\"\n", - " ],\n", - " \"ssm_cfg\": {\n", - " \"activation\": \"identity\",\n", - " \"bias\": false,\n", - " \"chunk_size\": 128,\n", - " \"d_inner\": 3072,\n", - " \"d_state\": 64,\n", - " \"expand\": 1,\n", - " \"n_qk_heads\": 24,\n", - " \"n_v_heads\": 24\n", - " },\n", - " \"tie_word_embeddings\": false,\n", - " \"torch_dtype\": \"bfloat16\",\n", - " \"transformers_version\": \"4.48.1\",\n", - " \"use_cache\": true,\n", - " \"vocab_size\": 131072\n", - "}" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "\n", "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "\n", + "# d_xb = config.num_key_value_heads * config.head_dim\n", + "d_inner = config.num_attention_heads * config.head_dim\n", + "d_state = config.head_dim\n", "hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),\n", " ssm_block_pattern=[\"m2d\", \"t\"] * 14,\n", - " ssm_cfg=None)\n", - "hybrdif_apriel_config" + " ssm_cfg={\n", + " \"d_state\": 64,\n", + " \"n_v_heads\": 24,\n", + " \"n_qk_heads\": 24,\n", + " # \"d_xb\": d_xb,\n", + " \"expand\": 1,\n", + " \"chunk_size\": 128,\n", + " \"activation\": \"identity\",\n", + " \"bias\": False,\n", + " \"d_inner\": 24 * 128, # num_heads * head_dim\n", + " })\n", + "# hybrdif_apriel_config" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 87, + "metadata": {}, + "outputs": [], + "source": [ + "hybrid_apriel_model = AprielSSMHybridModel(hybrdif_apriel_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 88, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "28" + "AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + ")" ] }, - "execution_count": 15, + "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "config.num_hidden_layers" + "hybrid_apriel_model.layers[0]" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 91, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 91, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "hybrid_apriel_model = AprielSSMHybridModel(hybrdif_apriel_config)" + "isinstance(hybrid_apriel_model.layers[0], AprielSSMDecoderLayer)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 84, "metadata": {}, "outputs": [], "source": [ - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "device = \"cpu\" #if torch.cuda.is_available() else \"cpu\"\n", "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", "batch_size = 1\n", "max_length = 128\n", @@ -2824,472 +2695,24 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 73, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "AprielSSMHybridModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (1): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (2): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (3): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (4): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (5): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (6): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (7): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (8): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (9): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (10): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (11): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (12): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (13): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (14): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (15): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (16): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (17): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (18): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (19): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (20): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (21): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (22): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (23): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (24): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (25): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (26): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (27): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (rotary_emb): AprielRotaryEmbedding()\n", - ")" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" + "ename": "OutOfMemoryError", + "evalue": "CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 79.10 GiB of which 1.72 GiB is free. Process 191417 has 19.83 GiB memory in use. Process 1524280 has 57.54 GiB memory in use. Of the allocated memory 18.11 GiB is allocated by PyTorch, and 1.05 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[73], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mhybrid_apriel_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mbfloat16)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3110\u001b[0m, in \u001b[0;36mPreTrainedModel.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3105\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype_present_in_args:\n\u001b[1;32m 3106\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 3107\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3108\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m `dtype` by passing the correct `torch_dtype` argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3109\u001b[0m )\n\u001b[0;32m-> 3110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1174\u001b[0m, in \u001b[0;36mModule.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1171\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1172\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[0;32m-> 1174\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconvert\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:780\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 780\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 784\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:805\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 801\u001b[0m \u001b[38;5;66;03m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[1;32m 802\u001b[0m \u001b[38;5;66;03m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[1;32m 803\u001b[0m \u001b[38;5;66;03m# `with torch.no_grad():`\u001b[39;00m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 805\u001b[0m param_applied \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 806\u001b[0m p_should_use_set_data \u001b[38;5;241m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[1;32m 808\u001b[0m \u001b[38;5;66;03m# subclasses may have multiple child tensors so we need to use swap_tensors\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1160\u001b[0m, in \u001b[0;36mModule.to..convert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 1153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m convert_to_format \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m t\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m5\u001b[39m):\n\u001b[1;32m 1154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(\n\u001b[1;32m 1155\u001b[0m device,\n\u001b[1;32m 1156\u001b[0m dtype \u001b[38;5;28;01mif\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1157\u001b[0m non_blocking,\n\u001b[1;32m 1158\u001b[0m memory_format\u001b[38;5;241m=\u001b[39mconvert_to_format,\n\u001b[1;32m 1159\u001b[0m )\n\u001b[0;32m-> 1160\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1161\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1162\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_floating_point\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_complex\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1163\u001b[0m \u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1164\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1165\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1166\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(e) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot copy out of meta tensor; no data!\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 79.10 GiB of which 1.72 GiB is free. Process 191417 has 19.83 GiB memory in use. Process 1524280 has 57.54 GiB memory in use. Of the allocated memory 18.11 GiB is allocated by PyTorch, and 1.05 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)" + ] } ], "source": [ @@ -3298,25 +2721,28 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 79, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "BaseModelOutputWithPast(last_hidden_state=tensor([[[ 2.2031, 0.1777, 0.4258, ..., -2.0312, 0.2246, 0.5664],\n", - " [ 0.0562, -1.1016, 0.4590, ..., -2.1719, -0.1455, -0.6992],\n", - " [-1.5078, -1.3516, 0.8789, ..., -1.9141, 1.3672, -1.0391],\n", - " ...,\n", - " [-1.4453, 0.1260, 0.6992, ..., 0.4746, -0.1729, -0.5938],\n", - " [-0.4961, -0.4160, -0.4551, ..., -0.1328, 0.7461, -0.0376],\n", - " [ 0.3184, 0.4355, -0.7578, ..., 1.5547, 0.8555, -0.8711]]],\n", - " device='cuda:0', dtype=torch.bfloat16, grad_fn=), past_key_values=DynamicCache(), hidden_states=None, attentions=None)" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" + "ename": "RuntimeError", + "evalue": "split_with_sizes expects split_sizes to sum exactly to 8216 (input tensor's size at dimension -1), but got split_sizes=[6144, 3072, 24]", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[79], line 2\u001b[0m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mhybrid_apriel_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mstatic_inputs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py:1043\u001b[0m, in \u001b[0;36mAprielSSMHybridModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, inference_params, **flash_attn_kwargs)\u001b[0m\n\u001b[1;32m 1041\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_hidden_states:\n\u001b[1;32m 1042\u001b[0m all_hidden_states \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (hidden_states,)\n\u001b[0;32m-> 1043\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1044\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1045\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1046\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1047\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1048\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1049\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1051\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1052\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1053\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1054\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1056\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1058\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(decoder_layer, AprielDecoderLayer):\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py:805\u001b[0m, in \u001b[0;36mAprielSSMDecoderLayer.forward\u001b[0;34m(self, hidden_states, inference_params, **kwargs)\u001b[0m\n\u001b[1;32m 801\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 803\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[0;32m--> 805\u001b[0m mixer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmixer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 806\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 807\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 808\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 810\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m mixer_outputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhidden_states\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mto(residual\u001b[38;5;241m.\u001b[39mdtype) \u001b[38;5;241m+\u001b[39m residual\n\u001b[1;32m 812\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py:460\u001b[0m, in \u001b[0;36mDiscreteMamba2.forward\u001b[0;34m(self, u, return_mixer_matrix, inference_params, **kwargs)\u001b[0m\n\u001b[1;32m 458\u001b[0m \u001b[38;5;66;03m# Project input\u001b[39;00m\n\u001b[1;32m 459\u001b[0m xBCzA_log \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj(u)\n\u001b[0;32m--> 460\u001b[0m xBC, z, A_log \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 461\u001b[0m \u001b[43m \u001b[49m\u001b[43mxBCzA_log\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 462\u001b[0m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\n\u001b[1;32m 463\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43md_inner\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_qk_heads\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43md_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 464\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43md_inner\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 465\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_v_heads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 466\u001b[0m \u001b[43m \u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 467\u001b[0m \u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 468\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 470\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m state \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 471\u001b[0m \u001b[38;5;66;03m# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv\u001b[39;00m\n\u001b[1;32m 472\u001b[0m \u001b[38;5;66;03m# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.\u001b[39;00m\n\u001b[1;32m 473\u001b[0m xBC_t \u001b[38;5;241m=\u001b[39m rearrange(xBC[:, :seqlen, :], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb l d -> b d l\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/functional.py:196\u001b[0m, in \u001b[0;36msplit\u001b[0;34m(tensor, split_size_or_sections, dim)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 191\u001b[0m split, (tensor,), tensor, split_size_or_sections, dim\u001b[38;5;241m=\u001b[39mdim)\n\u001b[1;32m 192\u001b[0m \u001b[38;5;66;03m# Overwriting reason:\u001b[39;00m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;66;03m# This dispatches to two ATen functions depending on the type of\u001b[39;00m\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# split_size_or_sections. The branching code is in _tensor.py, which we\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;66;03m# call here.\u001b[39;00m\n\u001b[0;32m--> 196\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtensor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\u001b[43msplit_size_or_sections\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/_tensor.py:917\u001b[0m, in \u001b[0;36mTensor.split\u001b[0;34m(self, split_size, dim)\u001b[0m\n\u001b[1;32m 915\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_VF\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;28mself\u001b[39m, split_size, dim) \u001b[38;5;66;03m# type: ignore[attr-defined]\u001b[39;00m\n\u001b[1;32m 916\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 917\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit_with_sizes\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mRuntimeError\u001b[0m: split_with_sizes expects split_sizes to sum exactly to 8216 (input tensor's size at dimension -1), but got split_sizes=[6144, 3072, 24]" + ] } ], "source": [ @@ -3326,102 +2752,139 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 9.73it/s]\n" + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 2.44it/s]\n" ] }, { - "ename": "RuntimeError", - "evalue": "CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[3], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m apriel_model \u001b[38;5;241m=\u001b[39m AutoModelForCausalLM\u001b[38;5;241m.\u001b[39mfrom_pretrained(checkpoint, torch_dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mbfloat16, trust_remote_code\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 5\u001b[0m apriel_state_dict \u001b[38;5;241m=\u001b[39m apriel_model\u001b[38;5;241m.\u001b[39mstate_dict()\n\u001b[0;32m----> 6\u001b[0m \u001b[43mapriel_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mbfloat16)\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3110\u001b[0m, in \u001b[0;36mPreTrainedModel.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3105\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype_present_in_args:\n\u001b[1;32m 3106\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 3107\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3108\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m `dtype` by passing the correct `torch_dtype` argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3109\u001b[0m )\n\u001b[0;32m-> 3110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1174\u001b[0m, in \u001b[0;36mModule.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1171\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1172\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[0;32m-> 1174\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconvert\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:780\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 780\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 784\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:780\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 780\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 784\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:805\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 801\u001b[0m \u001b[38;5;66;03m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[1;32m 802\u001b[0m \u001b[38;5;66;03m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[1;32m 803\u001b[0m \u001b[38;5;66;03m# `with torch.no_grad():`\u001b[39;00m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 805\u001b[0m param_applied \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 806\u001b[0m p_should_use_set_data \u001b[38;5;241m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[1;32m 808\u001b[0m \u001b[38;5;66;03m# subclasses may have multiple child tensors so we need to use swap_tensors\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1160\u001b[0m, in \u001b[0;36mModule.to..convert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 1153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m convert_to_format \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m t\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m5\u001b[39m):\n\u001b[1;32m 1154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(\n\u001b[1;32m 1155\u001b[0m device,\n\u001b[1;32m 1156\u001b[0m dtype \u001b[38;5;28;01mif\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1157\u001b[0m non_blocking,\n\u001b[1;32m 1158\u001b[0m memory_format\u001b[38;5;241m=\u001b[39mconvert_to_format,\n\u001b[1;32m 1159\u001b[0m )\n\u001b[0;32m-> 1160\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1161\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1162\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_floating_point\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_complex\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1163\u001b[0m \u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1164\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1165\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1166\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(e) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot copy out of meta tensor; no data!\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", - "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n" - ] + "data": { + "text/plain": [ + "AprielForCausalLM(\n", + " (model): AprielModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", - "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", "apriel_state_dict = apriel_model.state_dict()\n", - "apriel_model.to(device).to(dtype=torch.bfloat16)\n" + "apriel_model.to(device).to(dtype=torch.bfloat16)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 129, + "metadata": {}, + "outputs": [], + "source": [ + "# Innitialization using k, q, v from Apriel transformer\n", + "def expand_k_q(k):\n", + " Hq = config.num_attention_heads\n", + " Hk = config.num_key_value_heads\n", + " d_head = config.head_dim\n", + " d = k.shape[-1]\n", + " \n", + " # Expand k\n", + " repeat_factor = Hq // Hk\n", + " k_expanded = k.view(Hk, d_head, d)\n", + " k_expanded = k_expanded.repeat_interleave(repeat_factor, dim=0)\n", + " k_expanded = k_expanded.view(d_head * Hq, d)\n", + " return k_expanded\n", + "\n", + "for block_h, block_t in zip(hybrid_apriel_model.layers, apriel_model.model.layers):\n", + " # print(isinstance(block_h, AprielSSMDecoderLayer))\n", + " if isinstance(block_h, AprielSSMDecoderLayer):\n", + " # print(block_h.mixer.n_v_heads)\n", + " # print(block_t.self_attn.v_proj.weight.shape)\n", + " # print(block_h.mixer.in_proj.weight.shape)\n", + "\n", + " # print(block_h.mixer.in_proj.weight.shape)\n", + " # print(block_t.self_attn.v_proj.weight.shape)\n", + " block_h.mlp.load_state_dict(block_t.mlp.state_dict())\n", + " block_h.input_layernorm.load_state_dict(block_t.input_layernorm.state_dict())\n", + " block_h.post_attention_layernorm.load_state_dict(block_t.post_attention_layernorm.state_dict())\n", + " block_h.mixer.out_proj.load_state_dict(block_t.self_attn.o_proj.state_dict())\n", + " # [x B C z A_log]\n", + " # print(block_h.mixer.d_inner)\n", + " # init x, but interleave to address GQA\n", + " v_expended = expand_k_q(block_t.self_attn.v_proj.weight.data)\n", + " block_h.mixer.in_proj.weight.data[:block_h.mixer.d_inner, : ].copy_(v_expended)\n", + " # init k, but interleave to address GQA\n", + " k_expended = expand_k_q(block_t.self_attn.k_proj.weight.data)\n", + " block_h.mixer.in_proj.weight.data[block_h.mixer.d_inner: 2*block_h.mixer.d_inner, : ].copy_(k_expended)\n", + " # init C ewith Q\n", + " block_h.mixer.in_proj.weight.data[2*block_h.mixer.d_inner: 3*block_h.mixer.d_inner, : ].copy_(block_t.self_attn.q_proj.weight.data)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 124, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AprielConfig {\n", - " \"_name_or_path\": \"ServiceNow-AI/Apriel-5B-Instruct\",\n", - " \"architectures\": [\n", - " \"AprielForCausalLM\"\n", - " ],\n", - " \"attention_bias\": false,\n", - " \"attention_dropout\": 0.0,\n", - " \"auto_map\": {\n", - " \"AutoConfig\": \"ServiceNow-AI/Apriel-5B-Instruct--configuration_apriel.AprielConfig\",\n", - " \"AutoModelForCausalLM\": \"ServiceNow-AI/Apriel-5B-Instruct--modeling_apriel.AprielForCausalLM\"\n", - " },\n", - " \"bos_token_id\": 1,\n", - " \"eos_token_id\": 2,\n", - " \"head_dim\": 128,\n", - " \"hidden_act\": \"silu\",\n", - " \"hidden_size\": 4096,\n", - " \"initializer_range\": 0.02,\n", - " \"intermediate_size\": 8192,\n", - " \"max_position_embeddings\": 16384,\n", - " \"mlp_bias\": false,\n", - " \"model_type\": \"apriel\",\n", - " \"num_attention_heads\": 24,\n", - " \"num_hidden_layers\": 28,\n", - " \"num_key_value_heads\": 8,\n", - " \"pretraining_tp\": 1,\n", - " \"rms_norm_eps\": 1e-05,\n", - " \"rope_scaling\": {\n", - " \"attention_factor\": null,\n", - " \"beta_fast\": 32.0,\n", - " \"beta_slow\": 1.0,\n", - " \"factor\": 32.0,\n", - " \"original_max_position_embeddings\": 4096,\n", - " \"rope_type\": \"yarn\"\n", - " },\n", - " \"rope_theta\": 1000000.0,\n", - " \"tie_word_embeddings\": false,\n", - " \"torch_dtype\": \"bfloat16\",\n", - " \"transformers_version\": \"4.48.1\",\n", - " \"use_cache\": true,\n", - " \"vocab_size\": 131072\n", - "}" + "torch.Size([1024, 4096])" ] }, - "execution_count": 4, + "execution_count": 124, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "config" + "block_t.self_attn.v_proj.weight.data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#" ] }, { From f8ca1222a1f938a77338272d6c7f32bb9769d1a9 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 5 May 2025 01:22:06 +0000 Subject: [PATCH 065/122] embeddings_lr_scale --- fast_llm/layers/language_model/config.py | 7 +++++++ fast_llm/layers/language_model/embedding.py | 2 ++ 2 files changed, 9 insertions(+) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 0371eff43..1b7e2d945 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -203,6 +203,13 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + embeddings_lr_scale: float | None = Field( + default=None, + desc="Learning rate scale for the word embeddings.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) def _validate(self) -> None: self.transformer.validate() diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d9406ed1..e0386d8df 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -62,6 +62,7 @@ def __init__( min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.embeddings_lr_scale, ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( @@ -72,6 +73,7 @@ def __init__( max_val=config.init_method_max_embed, ), allow_sequence_tensor_parallel=not config.parallel_embeddings, + lr_scale=config.embeddings_lr_scale, ) # PEFT. From 2db740bcc1d23530a8db15b794b2f6778d8a61ca Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 5 May 2025 13:15:02 -0400 Subject: [PATCH 066/122] fix --- fast_llm/models/gpt/config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index e5afac160..418f948e3 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -179,9 +179,6 @@ def _validate(self) -> None: Assert.eq(self.reference_models.keys(), {name}) if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) - if self.model.base_model.distillation_model is not None: - # TODO: Support loss masking for distillation? - assert not self.batch.use_loss_masking_spans for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) # TODO: Support more LM head features. From 41d4da3491faa7ac7c1a6bd599be4fa41b97feeb Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 5 May 2025 21:50:05 +0000 Subject: [PATCH 067/122] disable freezing --- fast_llm/tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 849307563..611eb9f48 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -234,7 +234,9 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + # TODO: re-enable when fixed? + # self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + self.requires_grad = requires_grad # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) From 4160b1f3a189174502f27942d9f9ff995f34eb4d Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 6 May 2025 13:24:29 +0000 Subject: [PATCH 068/122] hybrid model loading and exporting --- fast_llm/models/ssm/config.py | 22 +- fast_llm/models/ssm/conversion.py | 247 +- .../configuration_ssm_apriel.py | 3 +- .../{ => aperiel_ssm}/modeling_ssm_apriel.py | 8 +- .../configuration_ssm_hybrid_apriel.py | 12 +- .../modeling_ssm_hybrid_apriel.py | 14 +- .../models/ssm/external/ariel_to_ssm.ipynb | 2989 ----------------- .../ssm/external/eval/apriel_eval_wrapper.py | 57 +- .../{ => llamba}/configuration_mtp_llamba.py | 0 .../{ => llamba}/modeling_mtp_llamba.py | 0 tests/test_ssms.py | 119 +- 11 files changed, 365 insertions(+), 3106 deletions(-) rename fast_llm/models/ssm/external/{ => aperiel_ssm}/configuration_ssm_apriel.py (95%) rename fast_llm/models/ssm/external/{ => aperiel_ssm}/modeling_ssm_apriel.py (98%) rename fast_llm/models/ssm/external/{ => apriel_hybrid}/configuration_ssm_hybrid_apriel.py (98%) rename fast_llm/models/ssm/external/{ => apriel_hybrid}/modeling_ssm_hybrid_apriel.py (99%) delete mode 100644 fast_llm/models/ssm/external/ariel_to_ssm.ipynb rename fast_llm/models/ssm/external/{ => llamba}/configuration_mtp_llamba.py (100%) rename fast_llm/models/ssm/external/{ => llamba}/modeling_mtp_llamba.py (100%) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 1d8ac007c..44093d1fd 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -23,12 +23,18 @@ class HybridSSMArchitectureConfig(LanguageModelArchitectureConfig): _abstract = False - hybrid_block_layout: list[str] = Field( - default_factory=lambda: [SSMBlockType.mamba2_discrete.value], + hybrid_block_layout: list[str] | None = Field( + default=None, desc=f"Pattern of blocks to use in the model. Availabel types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, ) + def _validate(self): + if self.hybrid_block_layout is None: + with self._set_implicit_default(): + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + super()._validate() + @config_class() class HybridSSMBaseModelConfig(LanguageModelBaseConfig, HybridSSMArchitectureConfig): @@ -133,6 +139,17 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHuggingfaceCheckpointHandler +class AprielSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "apriel_ssm_hybrid" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import AprielSSMHHybridHuggingfaceCheckpointHandler + + return AprielSSMHHybridHuggingfaceCheckpointHandler + + @config_class() class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False @@ -141,6 +158,7 @@ class HybridSSMModelConfig(FastLLMModelConfig): checkpoint_formats = FastLLMModelConfig.checkpoint_formats + ( LLambaHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, + AprielSSMHHybridHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 675c709f8..2e84cd10e 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -19,8 +19,9 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.ssm.config import SSMBlockType -from fast_llm.models.gpt.conversion import MLPLayer2Converter +from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( + AprielSSMHHybridHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat, @@ -72,10 +73,6 @@ class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandle @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), RenameParamConverter( fast_llm_names=(("ssm", "state_size"),), export_names=( @@ -143,12 +140,79 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ] + def _create_weight_converters(self) -> list[WeightConverter]: + converters = super()._create_weight_converters() + + num_layers = self._model.config.base_model.transformer.num_layers + ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear + + for i in range(num_layers): + # SSM + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias + ) + converters.append( + WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_weight", + f"model.layers.{i}.mixer.conv1d.weight", + self._model.config.base_model, + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_bias", + f"model.layers.{i}.mixer.conv1d.bias", + self._model.config.base_model, + ) + ) + + return converters + + def _get_weight_and_bias_converters( + self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + ) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters + -class LLambaHuggingfaceCheckpointHandler(HybridModelCheckpointHandler, CommonSSMHuggingfaceCheckpointHandler): +class LLambaHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat - _default_block_type: str = SSMBlockType.mamba2_discrete.value + _hf_prefix: str = "backbone" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -156,6 +220,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json """ return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), RenameParamConverter( fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) ), @@ -208,6 +276,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _create_weight_converters(self) -> list[WeightConverter]: + # not using super() because LLamba model is called backbone in the checkpoints converters = [] num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False @@ -215,58 +284,68 @@ def _create_weight_converters(self) -> list[WeightConverter]: # Embedding and output if self._model.config.base_model.tie_word_embeddings: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + converters.append( + WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") + ) + converters.append(IgnoreImportWeightConverter((), f"{self._hf_prefix}.lm_head.weight")) else: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + converters.append( + WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") + ) + converters.append( + WeightConverter(f"layers.{num_layers + 1}.output_weights", f"{self._hf_prefix}.lm_head.weight") + ) # Final norm converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "backbone.final_layernorm", norm_bias + f"layers.{num_layers + 1}.final_norm", f"{self._hf_prefix}.final_layernorm", norm_bias ) for i in range(num_layers): # SSM converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"backbone.layers.{i}.mixer.in_proj", ssm_bias + f"layers.{i+1}.mixer.in_proj", f"{self._hf_prefix}.layers.{i}.mixer.in_proj", ssm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"backbone.layers.{i}.mixer.out_proj", ssm_bias + f"layers.{i+1}.mixer.out_proj", f"{self._hf_prefix}.layers.{i}.mixer.out_proj", ssm_bias ) converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"backbone.layers.{i}.mixer.D", self._model.config.base_model) + WeightConverter( + f"layers.{i+1}.mixer.D", f"{self._hf_prefix}.layers.{i}.mixer.D", self._model.config.base_model + ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"backbone.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{i+1}.mixer.z_bias", + f"{self._hf_prefix}.layers.{i}.mixer.z_bias", + self._model.config.base_model, ) ) converters.append( WeightConverter( f"layers.{i+1}.mixer.conv1d_weight", - f"backbone.layers.{i}.mixer.conv1d.weight", + f"{self._hf_prefix}.layers.{i}.mixer.conv1d.weight", self._model.config.base_model, ) ) converters.append( WeightConverter( f"layers.{i+1}.mixer.conv1d_bias", - f"backbone.layers.{i}.mixer.conv1d.bias", + f"{self._hf_prefix}.layers.{i}.mixer.conv1d.bias", self._model.config.base_model, ) ) # Norm converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_1", f"backbone.layers.{i}.input_layernorm", norm_bias + f"layers.{i+1}.norm_1", f"{self._hf_prefix}.layers.{i}.input_layernorm", norm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_2", f"backbone.layers.{i}.post_attention_layernorm", norm_bias + f"layers.{i+1}.norm_2", f"{self._hf_prefix}.layers.{i}.post_attention_layernorm", norm_bias ) # MLP - converters += self._get_mlp_converters(f"layers.{i+1}", f"backbone.layers.{i}") + converters += self._get_mlp_converters(f"layers.{i+1}", f"{self._hf_prefix}.layers.{i}") return converters @@ -330,14 +409,22 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An json.dump(config, f) -class AprielSSMHuggingfaceCheckpointHandler(HybridModelCheckpointHandler, CommonSSMHuggingfaceCheckpointHandler): +class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): + """ + Lamba-like configs, pure SSM models. + """ + + _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat - _default_block_type: str = SSMBlockType.mamba2_discrete.value @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), RenameParamConverter( fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), @@ -377,10 +464,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _create_weight_converters(self) -> list[WeightConverter]: - converters = [] + converters = super()._create_weight_converters() num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False - ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear # Embedding and output if self._model.config.base_model.tie_word_embeddings: @@ -396,36 +482,6 @@ def _create_weight_converters(self) -> list[WeightConverter]: ) for i in range(num_layers): - # SSM - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias - ) - converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_weight", - f"model.layers.{i}.mixer.conv1d.weight", - self._model.config.base_model, - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_bias", - f"model.layers.{i}.mixer.conv1d.bias", - self._model.config.base_model, - ) - ) - # Norm converters += self._get_weight_and_bias_converters( f"layers.{i+1}.norm_1", f"model.layers.{i}.input_layernorm", norm_bias @@ -456,33 +512,62 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ), ] - def _get_weight_and_bias_converters( - self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, - ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + if not os.path.exists(directory / "config.json"): + raise FileNotFoundError(f"config.json not found in {directory}") + with open(directory / "config.json") as f: + config = json.load(f) + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + with open(directory / "config.json", "w") as f: + json.dump(config, f) + + +class AprielSSMHHybridHuggingfaceCheckpointHandler( + HybridModelCheckpointHandler, # handles the block structure parameter + CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers + CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers +): + """ + Lamba-like configs, models that interleave LLama like layers with LLamba-like SSM layers. + """ + + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHHybridHuggingfaceCheckpointFormat + _default_block_type: str = SSMBlockType.mamba2_discrete.value + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("ssm", "d_inner"),), + export_names=(("ssm_cfg", "d_inner"),), + ), + ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), + ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), + ] + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) - ) - return converters @classmethod def _load_config(cls, directory: pathlib.Path | str) -> dict: diff --git a/fast_llm/models/ssm/external/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py similarity index 95% rename from fast_llm/models/ssm/external/configuration_ssm_apriel.py rename to fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py index 2e5d5810c..c3f7ef38d 100644 --- a/fast_llm/models/ssm/external/configuration_ssm_apriel.py +++ b/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py @@ -96,7 +96,8 @@ def __init__( "bias": False, "d_inner": 24 * self.head_dim, # num_heads * head_dim } - assert self.head_dim == self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"] + if self.head_dim == self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"]: + logger.warning("Head dim is equal to d_inner // n_qk_heads.") __all__ = ["AprielConfig"] diff --git a/fast_llm/models/ssm/external/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/aperiel_ssm/modeling_ssm_apriel.py similarity index 98% rename from fast_llm/models/ssm/external/modeling_ssm_apriel.py rename to fast_llm/models/ssm/external/aperiel_ssm/modeling_ssm_apriel.py index 5a1b8db42..dd228024c 100644 --- a/fast_llm/models/ssm/external/modeling_ssm_apriel.py +++ b/fast_llm/models/ssm/external/aperiel_ssm/modeling_ssm_apriel.py @@ -19,7 +19,7 @@ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils.generic import ModelOutput -from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig +from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig logger = logging.get_logger(__name__) @@ -172,7 +172,7 @@ def __init__( **factory_kwargs, ) self.z_bias = ( - nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + nn.Parameter(torch.zeros(self.d_inner, **factory_kwargs)) if not bias else 0 ) # make sure z_bias always exists # Convolutional layer @@ -197,7 +197,7 @@ def __init__( raise ValueError(f"Unknown activation {self.activation}") # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D = nn.Parameter(torch.ones(self.n_v_heads, **factory_kwargs)) self.D._optim = {"weight_decay": 0.0} # out_proj @@ -670,7 +670,7 @@ class AprielSSMForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): def __init__(self, config, device=None, dtype=None, **kwargs): super().__init__(config, device=device, dtype=dtype, **kwargs) - self.model = AprielSSMModel(config) + self.model = AprielSSMModel(config, device=device, dtype=dtype) self.vocab_size = config.vocab_size factory_kwargs = {"device": device, "dtype": dtype} self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) diff --git a/fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py similarity index 98% rename from fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py rename to fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py index 588918025..b030150ce 100644 --- a/fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py @@ -344,7 +344,7 @@ class AprielSSMHybridConfig(PretrainedConfig): >>> configuration = model.config ```""" - model_type = "apriel" + model_type = "apriel_ssm_hybrid" keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `AprielModel` base_model_tp_plan = { @@ -386,7 +386,7 @@ def __init__( attention_dropout=0.0, mlp_bias=False, head_dim=None, - ssm_block_pattern=["m2d"], + hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs, ): @@ -413,10 +413,10 @@ def __init__( self.attention_dropout = attention_dropout self.mlp_bias = mlp_bias self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads - self.ssm_block_pattern = ssm_block_pattern - if len(ssm_block_pattern) == 1: - self.ssm_block_pattern = [ssm_block_pattern[0]] * self.num_hidden_layers - assert len(self.ssm_block_pattern) == self.num_hidden_layers + self.hybrid_block_layout = hybrid_block_layout + if len(hybrid_block_layout) == 1: + self.hybrid_block_layout = [hybrid_block_layout[0]] * self.num_hidden_layers + assert len(self.hybrid_block_layout) == self.num_hidden_layers # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, copy it it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: diff --git a/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py similarity index 99% rename from fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py rename to fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index 49b009866..950327df9 100644 --- a/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -21,7 +21,10 @@ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils.generic import ModelOutput -from fast_llm.models.ssm.external.configuration_ssm_hybrid_apriel import ROPE_INIT_FUNCTIONS, AprielSSMHybridConfig +from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import ( + ROPE_INIT_FUNCTIONS, + AprielSSMHybridConfig, +) logger = logging.get_logger(__name__) @@ -875,7 +878,7 @@ def __init__(self, config: AprielSSMHybridConfig, device=None, dtype=None, **kwa factory_kwargs = {"device": device, "dtype": dtype} self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) blocks = [] - for layer_idx, type in enumerate(config.ssm_block_pattern): + for layer_idx, type in enumerate(config.hybrid_block_layout): if type == "m2d": blocks.append(AprielSSMDecoderLayer(config, layer_idx, **factory_kwargs)) elif type == "t": @@ -1169,11 +1172,12 @@ def forward( **kwargs: Unpack[KwargsForCausalLM], ) -> Union[tuple, CausalLMOutputWithPast]: - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, return_hidden_states=return_hidden_states, inference_params=inference_params, position_ids=position_ids, + return_dict=True, ) if outputs["last_hidden_state"] is not None and return_logits: @@ -1185,8 +1189,8 @@ def forward( return CustomMambaCausalLMOutput( loss=None, logits=outputs["logits"], - all_hidden_states=outputs["all_hidden_states"], - last_hidden_state=outputs["last_hidden_state"], + all_hidden_states=outputs.hidden_states, + last_hidden_state=outputs.last_hidden_state, ) def generate(self, *args, **kwargs): diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb deleted file mode 100644 index 664d927fa..000000000 --- a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb +++ /dev/null @@ -1,2989 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "import torch\n", - "from mamba_ssm import MambaLMHeadModel\n", - "from mamba_ssm.models.config_mamba import MambaConfig\n", - "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", - "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", - "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", - "from transformers.cache_utils import StaticCache\n", - "from types import SimpleNamespace\n", - "\n", - "# make sure the code changes reflected without reload\n", - "%load_ext autoreload\n", - "%autoreload 2\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Apriel SSM for distillation" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 8.90it/s]\n" - ] - }, - { - "data": { - "text/plain": [ - "AprielForCausalLM(\n", - " (model): AprielModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (rotary_emb): AprielRotaryEmbedding()\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", - "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", - "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", - "apriel_state_dict = apriel_model.state_dict()\n", - "apriel_model.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.bfloat16" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_model.config.torch_dtype" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n" - ] - } - ], - "source": [ - "config_apriel = AprielSSMConfig.from_pretrained(\"/mnt/checkpoints_fml/pretrained_models/ssm/apriel_ssm_instruct_base\", trust_remote_code=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "apriel_ssm = AprielSSMForCausalLM(apriel_ssm_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "OrderedDict([('model.embed_tokens.weight',\n", - " tensor([[ 0.0105, 0.0330, -0.0032, ..., 0.0076, -0.0051, 0.0112],\n", - " [-0.0111, -0.0101, 0.0064, ..., 0.0144, 0.0098, -0.0194],\n", - " [ 0.0301, 0.0228, 0.0105, ..., -0.0159, 0.0112, -0.0009],\n", - " ...,\n", - " [ 0.0266, 0.0224, -0.0150, ..., 0.0189, -0.0253, -0.0300],\n", - " [-0.0304, 0.0249, 0.0140, ..., -0.0235, 0.0315, -0.0188],\n", - " [-0.0215, -0.0034, 0.0035, ..., -0.0125, 0.0084, 0.0246]])),\n", - " ('model.layers.0.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.0.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.0.mixer.in_proj.weight',\n", - " tensor([[ 0.0104, 0.0055, -0.0148, ..., 0.0208, -0.0074, 0.0015],\n", - " [ 0.0102, 0.0148, 0.0148, ..., -0.0041, 0.0224, -0.0336],\n", - " [ 0.0129, -0.0179, -0.0120, ..., 0.0175, 0.0300, -0.0234],\n", - " ...,\n", - " [-0.0215, 0.0002, 0.0093, ..., -0.0424, 0.0016, -0.0162],\n", - " [-0.0178, -0.0093, 0.0226, ..., 0.0005, 0.0062, 0.0150],\n", - " [-0.0204, 0.0039, -0.0364, ..., -0.0128, 0.0002, 0.0134]])),\n", - " ('model.layers.0.mixer.conv1d.weight',\n", - " tensor([[[-0.1064, -0.3782, -0.3080, -0.3179]],\n", - " \n", - " [[-0.3493, 0.2230, 0.1062, 0.0614]],\n", - " \n", - " [[-0.4650, 0.0300, 0.3021, 0.1197]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.3686, 0.0679, 0.1440, 0.4445]],\n", - " \n", - " [[-0.1480, 0.3750, -0.0552, -0.0297]],\n", - " \n", - " [[ 0.0677, 0.0925, -0.0268, -0.0232]]])),\n", - " ('model.layers.0.mixer.conv1d.bias',\n", - " tensor([ 0.1379, 0.0862, -0.0723, ..., -0.2628, -0.1867, -0.1233])),\n", - " ('model.layers.0.mixer.out_proj.weight',\n", - " tensor([[ 0.0208, -0.0106, -0.0016, ..., 0.0117, 0.0140, -0.0040],\n", - " [-0.0147, 0.0419, 0.0327, ..., -0.0073, -0.0127, 0.0190],\n", - " [-0.0218, 0.0030, 0.0115, ..., -0.0062, 0.0214, 0.0105],\n", - " ...,\n", - " [ 0.0089, 0.0154, -0.0178, ..., -0.0206, -0.0378, 0.0102],\n", - " [ 0.0153, -0.0249, 0.0219, ..., 0.0119, 0.0019, 0.0383],\n", - " [-0.0126, 0.0284, -0.0035, ..., 0.0118, -0.0186, -0.0232]])),\n", - " ('model.layers.0.mlp.gate_proj.weight',\n", - " tensor([[-0.0032, -0.0405, 0.0180, ..., -0.0030, -0.0222, 0.0069],\n", - " [-0.0071, -0.0064, -0.0207, ..., 0.0037, -0.0077, 0.0261],\n", - " [ 0.0236, 0.0167, 0.0065, ..., 0.0064, 0.0035, -0.0092],\n", - " ...,\n", - " [-0.0357, 0.0192, 0.0099, ..., -0.0067, -0.0181, 0.0082],\n", - " [-0.0139, -0.0161, -0.0015, ..., -0.0052, -0.0337, 0.0514],\n", - " [ 0.0105, -0.0205, 0.0198, ..., 0.0090, 0.0315, 0.0066]])),\n", - " ('model.layers.0.mlp.up_proj.weight',\n", - " tensor([[ 0.0074, 0.0237, -0.0300, ..., 0.0343, 0.0016, 0.0395],\n", - " [ 0.0270, 0.0085, 0.0193, ..., 0.0199, -0.0139, 0.0094],\n", - " [ 0.0036, 0.0073, 0.0149, ..., 0.0094, 0.0346, -0.0111],\n", - " ...,\n", - " [ 0.0159, -0.0346, -0.0128, ..., 0.0377, -0.0531, -0.0305],\n", - " [ 0.0283, 0.0162, -0.0377, ..., -0.0254, 0.0110, -0.0167],\n", - " [-0.0277, 0.0130, 0.0161, ..., 0.0089, -0.0190, 0.0214]])),\n", - " ('model.layers.0.mlp.down_proj.weight',\n", - " tensor([[ 0.0157, 0.0105, 0.0036, ..., 0.0229, 0.0080, 0.0303],\n", - " [-0.0143, -0.0067, 0.0016, ..., 0.0494, -0.0043, 0.0072],\n", - " [-0.0148, 0.0113, 0.0025, ..., -0.0186, 0.0206, -0.0119],\n", - " ...,\n", - " [-0.0226, 0.0099, 0.0010, ..., 0.0123, -0.0170, 0.0024],\n", - " [-0.0120, -0.0015, -0.0355, ..., 0.0064, 0.0175, -0.0065],\n", - " [ 0.0364, 0.0364, 0.0265, ..., -0.0222, 0.0030, 0.0296]])),\n", - " ('model.layers.0.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.0.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.1.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.1.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.1.mixer.in_proj.weight',\n", - " tensor([[-0.0116, -0.0182, -0.0017, ..., -0.0216, -0.0136, -0.0203],\n", - " [-0.0142, -0.0106, -0.0334, ..., 0.0287, -0.0273, 0.0050],\n", - " [ 0.0131, -0.0106, -0.0012, ..., 0.0261, -0.0228, -0.0026],\n", - " ...,\n", - " [-0.0029, 0.0023, 0.0360, ..., -0.0195, 0.0018, -0.0227],\n", - " [ 0.0004, 0.0015, -0.0051, ..., -0.0095, 0.0269, 0.0179],\n", - " [ 0.0295, -0.0520, 0.0009, ..., 0.0019, 0.0255, 0.0478]])),\n", - " ('model.layers.1.mixer.conv1d.weight',\n", - " tensor([[[-0.4725, -0.2938, -0.3816, -0.1239]],\n", - " \n", - " [[-0.2002, 0.3790, 0.1908, -0.4679]],\n", - " \n", - " [[-0.3674, 0.3774, -0.2479, 0.4324]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.4181, 0.2263, -0.1937, 0.3585]],\n", - " \n", - " [[ 0.0704, 0.0913, 0.4217, 0.3004]],\n", - " \n", - " [[ 0.3175, -0.3239, -0.0614, -0.3978]]])),\n", - " ('model.layers.1.mixer.conv1d.bias',\n", - " tensor([ 0.4302, 0.0269, -0.3462, ..., 0.4887, 0.2848, 0.0745])),\n", - " ('model.layers.1.mixer.out_proj.weight',\n", - " tensor([[-0.0069, 0.0233, 0.0133, ..., -0.0064, -0.0085, 0.0166],\n", - " [-0.0302, 0.0129, -0.0042, ..., 0.0109, 0.0009, -0.0087],\n", - " [-0.0373, -0.0233, -0.0043, ..., -0.0017, 0.0384, -0.0114],\n", - " ...,\n", - " [-0.0219, 0.0330, -0.0341, ..., 0.0080, 0.0089, 0.0268],\n", - " [-0.0019, -0.0069, 0.0276, ..., 0.0182, -0.0240, 0.0163],\n", - " [ 0.0081, 0.0070, 0.0156, ..., -0.0135, 0.0469, -0.0221]])),\n", - " ('model.layers.1.mlp.gate_proj.weight',\n", - " tensor([[ 0.0175, -0.0074, -0.0028, ..., 0.0197, 0.0034, 0.0221],\n", - " [ 0.0063, 0.0339, -0.0047, ..., 0.0037, -0.0126, -0.0342],\n", - " [-0.0093, -0.0148, -0.0236, ..., 0.0190, -0.0451, -0.0173],\n", - " ...,\n", - " [ 0.0167, 0.0161, 0.0019, ..., -0.0083, -0.0133, 0.0141],\n", - " [-0.0163, 0.0383, -0.0203, ..., 0.0336, -0.0148, 0.0013],\n", - " [-0.0138, -0.0275, -0.0268, ..., -0.0243, -0.0031, -0.0227]])),\n", - " ('model.layers.1.mlp.up_proj.weight',\n", - " tensor([[ 0.0054, 0.0031, 0.0256, ..., 0.0002, 0.0020, -0.0050],\n", - " [ 0.0247, -0.0298, -0.0218, ..., -0.0161, 0.0253, 0.0128],\n", - " [-0.0231, -0.0012, 0.0130, ..., 0.0031, -0.0324, 0.0107],\n", - " ...,\n", - " [ 0.0359, -0.0202, 0.0386, ..., -0.0104, 0.0274, 0.0161],\n", - " [ 0.0062, -0.0111, 0.0338, ..., 0.0041, 0.0001, -0.0019],\n", - " [ 0.0105, -0.0258, 0.0184, ..., -0.0270, -0.0138, -0.0367]])),\n", - " ('model.layers.1.mlp.down_proj.weight',\n", - " tensor([[-0.0163, -0.0308, -0.0203, ..., 0.0002, -0.0227, 0.0019],\n", - " [ 0.0206, 0.0037, 0.0064, ..., -0.0261, -0.0206, 0.0063],\n", - " [ 0.0044, -0.0073, -0.0576, ..., -0.0015, -0.0082, 0.0022],\n", - " ...,\n", - " [-0.0034, 0.0142, -0.0547, ..., -0.0106, -0.0090, 0.0249],\n", - " [-0.0068, 0.0127, -0.0066, ..., -0.0255, 0.0004, 0.0106],\n", - " [-0.0293, 0.0146, -0.0142, ..., -0.0073, -0.0284, -0.0069]])),\n", - " ('model.layers.1.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.1.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.2.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.2.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.2.mixer.in_proj.weight',\n", - " tensor([[ 0.0337, -0.0055, -0.0538, ..., -0.0051, 0.0107, -0.0338],\n", - " [ 0.0227, -0.0008, 0.0003, ..., -0.0312, 0.0090, -0.0126],\n", - " [-0.0238, 0.0146, 0.0240, ..., -0.0114, -0.0180, 0.0025],\n", - " ...,\n", - " [-0.0208, -0.0261, 0.0227, ..., 0.0071, 0.0014, 0.0237],\n", - " [ 0.0356, 0.0372, 0.0186, ..., 0.0052, 0.0049, -0.0195],\n", - " [ 0.0023, -0.0159, -0.0238, ..., 0.0194, -0.0056, -0.0275]])),\n", - " ('model.layers.2.mixer.conv1d.weight',\n", - " tensor([[[ 0.1054, -0.4185, 0.4229, 0.3289]],\n", - " \n", - " [[-0.0081, 0.0321, 0.1334, -0.1055]],\n", - " \n", - " [[ 0.1587, -0.3806, -0.1336, -0.2662]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.2830, -0.3875, -0.2972, 0.0030]],\n", - " \n", - " [[ 0.4210, 0.2190, -0.4942, 0.0465]],\n", - " \n", - " [[-0.1830, -0.3686, 0.2928, -0.0313]]])),\n", - " ('model.layers.2.mixer.conv1d.bias',\n", - " tensor([-0.2931, -0.3513, -0.3013, ..., -0.1934, -0.3115, 0.3889])),\n", - " ('model.layers.2.mixer.out_proj.weight',\n", - " tensor([[-0.0038, -0.0160, -0.0042, ..., 0.0062, 0.0059, -0.0126],\n", - " [-0.0027, -0.0012, -0.0065, ..., -0.0032, 0.0129, -0.0298],\n", - " [ 0.0394, -0.0096, 0.0107, ..., -0.0290, 0.0248, 0.0308],\n", - " ...,\n", - " [ 0.0087, 0.0067, -0.0261, ..., -0.0038, -0.0168, 0.0485],\n", - " [ 0.0118, 0.0042, -0.0186, ..., 0.0104, 0.0281, 0.0028],\n", - " [ 0.0304, -0.0382, -0.0028, ..., -0.0264, -0.0050, 0.0050]])),\n", - " ('model.layers.2.mlp.gate_proj.weight',\n", - " tensor([[-0.0169, 0.0036, 0.0024, ..., 0.0429, 0.0313, 0.0167],\n", - " [-0.0100, 0.0011, -0.0024, ..., -0.0065, 0.0090, 0.0123],\n", - " [ 0.0102, 0.0282, 0.0166, ..., -0.0082, 0.0123, 0.0253],\n", - " ...,\n", - " [ 0.0168, -0.0056, -0.0096, ..., -0.0090, 0.0150, 0.0209],\n", - " [ 0.0258, 0.0113, -0.0093, ..., 0.0335, 0.0386, -0.0156],\n", - " [ 0.0129, 0.0338, -0.0006, ..., -0.0346, 0.0135, -0.0213]])),\n", - " ('model.layers.2.mlp.up_proj.weight',\n", - " tensor([[-0.0029, 0.0416, -0.0102, ..., -0.0413, 0.0019, 0.0063],\n", - " [ 0.0054, 0.0138, 0.0031, ..., -0.0077, -0.0070, -0.0016],\n", - " [ 0.0128, 0.0153, -0.0147, ..., -0.0131, -0.0244, 0.0097],\n", - " ...,\n", - " [-0.0190, -0.0025, 0.0322, ..., -0.0106, -0.0323, -0.0144],\n", - " [-0.0269, -0.0007, 0.0070, ..., 0.0191, -0.0025, 0.0033],\n", - " [-0.0311, 0.0217, -0.0021, ..., 0.0302, -0.0131, 0.0388]])),\n", - " ('model.layers.2.mlp.down_proj.weight',\n", - " tensor([[ 0.0150, -0.0127, 0.0372, ..., 0.0018, 0.0018, 0.0187],\n", - " [-0.0262, 0.0164, 0.0281, ..., 0.0120, -0.0187, -0.0177],\n", - " [ 0.0129, -0.0042, 0.0018, ..., -0.0136, 0.0278, 0.0284],\n", - " ...,\n", - " [ 0.0048, 0.0421, -0.0018, ..., 0.0002, -0.0064, 0.0085],\n", - " [ 0.0276, 0.0146, 0.0228, ..., 0.0055, -0.0288, -0.0081],\n", - " [-0.0133, 0.0102, 0.0318, ..., 0.0209, -0.0270, 0.0128]])),\n", - " ('model.layers.2.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.2.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.3.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.3.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.3.mixer.in_proj.weight',\n", - " tensor([[ 7.4766e-03, -9.8698e-03, -1.9172e-02, ..., 3.7842e-02,\n", - " -2.1648e-03, 2.8147e-03],\n", - " [ 2.4954e-02, -1.2659e-02, 8.0447e-04, ..., 3.1716e-02,\n", - " 4.9989e-03, 6.4200e-03],\n", - " [-3.3345e-02, -1.5256e-02, 2.7295e-02, ..., -1.1240e-02,\n", - " 9.7000e-03, 3.1136e-05],\n", - " ...,\n", - " [-2.0807e-04, -2.5132e-02, -1.9983e-02, ..., -2.9541e-02,\n", - " 4.6152e-04, 5.5341e-02],\n", - " [ 2.0498e-03, 2.2021e-02, -7.6882e-03, ..., 1.6469e-02,\n", - " -1.0645e-02, -1.8442e-03],\n", - " [ 2.0949e-03, -1.2398e-02, 1.2922e-02, ..., 1.1862e-02,\n", - " -4.7119e-03, 3.2352e-02]])),\n", - " ('model.layers.3.mixer.conv1d.weight',\n", - " tensor([[[ 0.2590, 0.1670, 0.3987, -0.1694]],\n", - " \n", - " [[-0.4425, 0.1468, 0.3060, -0.0764]],\n", - " \n", - " [[-0.3638, -0.0575, 0.2156, -0.2468]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0111, -0.0182, -0.3816, 0.0382]],\n", - " \n", - " [[-0.4723, -0.3712, 0.1963, 0.2877]],\n", - " \n", - " [[-0.4890, 0.1197, 0.1361, 0.3282]]])),\n", - " ('model.layers.3.mixer.conv1d.bias',\n", - " tensor([-0.4712, -0.3272, 0.4587, ..., -0.3145, 0.4086, 0.4005])),\n", - " ('model.layers.3.mixer.out_proj.weight',\n", - " tensor([[-0.0362, 0.0137, -0.0296, ..., -0.0028, 0.0104, 0.0393],\n", - " [ 0.0130, 0.0246, -0.0132, ..., 0.0082, -0.0044, -0.0054],\n", - " [-0.0081, -0.0115, -0.0064, ..., 0.0250, -0.0076, -0.0021],\n", - " ...,\n", - " [ 0.0230, -0.0055, 0.0056, ..., 0.0076, 0.0016, -0.0068],\n", - " [ 0.0472, -0.0068, 0.0336, ..., 0.0079, 0.0211, 0.0031],\n", - " [-0.0450, -0.0005, 0.0219, ..., 0.0044, -0.0006, -0.0278]])),\n", - " ('model.layers.3.mlp.gate_proj.weight',\n", - " tensor([[ 0.0034, 0.0445, -0.0132, ..., 0.0290, 0.0019, 0.0048],\n", - " [ 0.0271, 0.0109, 0.0028, ..., -0.0304, -0.0237, -0.0017],\n", - " [ 0.0098, 0.0252, 0.0392, ..., 0.0486, 0.0326, -0.0171],\n", - " ...,\n", - " [-0.0015, 0.0080, 0.0005, ..., -0.0158, -0.0067, 0.0347],\n", - " [-0.0638, 0.0120, 0.0076, ..., 0.0007, 0.0052, -0.0109],\n", - " [-0.0303, -0.0168, -0.0537, ..., -0.0163, -0.0030, -0.0068]])),\n", - " ('model.layers.3.mlp.up_proj.weight',\n", - " tensor([[-0.0074, -0.0101, 0.0073, ..., -0.0012, -0.0208, -0.0239],\n", - " [ 0.0035, 0.0010, 0.0157, ..., -0.0228, -0.0224, 0.0194],\n", - " [ 0.0457, -0.0129, -0.0063, ..., -0.0312, 0.0261, -0.0018],\n", - " ...,\n", - " [ 0.0012, 0.0093, 0.0121, ..., -0.0035, -0.0367, -0.0454],\n", - " [ 0.0308, -0.0334, 0.0062, ..., 0.0043, -0.0031, -0.0406],\n", - " [-0.0175, -0.0089, -0.0137, ..., -0.0322, -0.0070, -0.0219]])),\n", - " ('model.layers.3.mlp.down_proj.weight',\n", - " tensor([[ 0.0226, 0.0074, -0.0170, ..., 0.0035, 0.0420, -0.0085],\n", - " [ 0.0116, 0.0173, -0.0009, ..., -0.0302, 0.0075, 0.0153],\n", - " [-0.0092, 0.0119, 0.0164, ..., 0.0233, -0.0177, -0.0397],\n", - " ...,\n", - " [-0.0006, -0.0275, 0.0127, ..., -0.0185, 0.0335, -0.0133],\n", - " [ 0.0064, -0.0200, 0.0296, ..., 0.0041, -0.0114, -0.0221],\n", - " [ 0.0317, 0.0392, 0.0553, ..., 0.0191, 0.0188, -0.0176]])),\n", - " ('model.layers.3.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.3.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.4.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.4.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.4.mixer.in_proj.weight',\n", - " tensor([[-0.0266, 0.0092, -0.0260, ..., -0.0121, -0.0286, 0.0267],\n", - " [ 0.0144, -0.0053, -0.0060, ..., -0.0065, 0.0201, -0.0025],\n", - " [-0.0092, -0.0465, -0.0032, ..., 0.0192, -0.0026, 0.0104],\n", - " ...,\n", - " [-0.0210, -0.0286, -0.0148, ..., 0.0593, 0.0130, 0.0118],\n", - " [ 0.0361, -0.0070, 0.0054, ..., -0.0073, 0.0004, 0.0287],\n", - " [ 0.0450, -0.0286, 0.0191, ..., -0.0180, 0.0039, -0.0033]])),\n", - " ('model.layers.4.mixer.conv1d.weight',\n", - " tensor([[[ 0.1450, 0.2065, -0.1750, -0.4560]],\n", - " \n", - " [[-0.2889, -0.4707, -0.0741, 0.1254]],\n", - " \n", - " [[-0.4665, 0.1876, -0.4049, 0.1143]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0709, 0.2021, -0.0053, -0.1558]],\n", - " \n", - " [[-0.0195, -0.4046, -0.2437, -0.4405]],\n", - " \n", - " [[-0.3615, -0.4314, 0.1667, 0.3139]]])),\n", - " ('model.layers.4.mixer.conv1d.bias',\n", - " tensor([-0.3220, -0.4181, -0.0623, ..., 0.2788, 0.0518, 0.4607])),\n", - " ('model.layers.4.mixer.out_proj.weight',\n", - " tensor([[-0.0011, -0.0279, -0.0160, ..., -0.0222, 0.0262, 0.0234],\n", - " [ 0.0024, 0.0178, -0.0142, ..., 0.0048, -0.0145, 0.0332],\n", - " [-0.0084, -0.0037, 0.0054, ..., -0.0201, -0.0341, -0.0053],\n", - " ...,\n", - " [-0.0120, -0.0440, 0.0097, ..., -0.0070, -0.0129, 0.0170],\n", - " [ 0.0096, -0.0034, -0.0025, ..., 0.0242, 0.0047, 0.0093],\n", - " [ 0.0254, 0.0207, 0.0135, ..., 0.0204, -0.0185, -0.0026]])),\n", - " ('model.layers.4.mlp.gate_proj.weight',\n", - " tensor([[ 0.0049, 0.0087, 0.0081, ..., 0.0145, 0.0188, 0.0441],\n", - " [-0.0103, 0.0147, 0.0180, ..., -0.0190, 0.0182, 0.0160],\n", - " [-0.0041, 0.0289, 0.0106, ..., 0.0144, -0.0070, 0.0104],\n", - " ...,\n", - " [ 0.0086, 0.0079, 0.0155, ..., 0.0037, -0.0242, 0.0091],\n", - " [-0.0320, 0.0084, -0.0508, ..., 0.0003, -0.0120, 0.0129],\n", - " [ 0.0079, 0.0185, 0.0285, ..., -0.0324, 0.0444, -0.0147]])),\n", - " ('model.layers.4.mlp.up_proj.weight',\n", - " tensor([[ 3.4382e-03, 1.9171e-02, 4.1226e-03, ..., 1.3158e-02,\n", - " 3.6365e-02, -8.1017e-03],\n", - " [ 1.8713e-02, -2.7732e-03, 3.1982e-02, ..., -8.5724e-03,\n", - " -3.1505e-02, 2.1047e-03],\n", - " [ 1.2329e-02, 1.8352e-03, 9.2540e-03, ..., 2.9880e-02,\n", - " -2.7856e-04, -8.7440e-04],\n", - " ...,\n", - " [-2.2330e-02, -2.0716e-02, 9.0004e-05, ..., -1.6298e-02,\n", - " -1.9620e-02, 2.5112e-02],\n", - " [ 7.1659e-03, 1.2942e-02, 1.0291e-03, ..., -1.0113e-02,\n", - " -1.6838e-03, 2.0189e-02],\n", - " [ 7.2108e-03, 3.1229e-02, 2.2533e-03, ..., -2.0148e-02,\n", - " -1.3502e-02, -1.8923e-02]])),\n", - " ('model.layers.4.mlp.down_proj.weight',\n", - " tensor([[ 0.0140, -0.0129, 0.0005, ..., -0.0068, -0.0335, 0.0172],\n", - " [-0.0175, -0.0011, 0.0114, ..., -0.0087, -0.0048, -0.0231],\n", - " [-0.0053, -0.0079, -0.0172, ..., -0.0125, -0.0200, 0.0127],\n", - " ...,\n", - " [ 0.0321, -0.0039, 0.0142, ..., 0.0384, 0.0054, 0.0321],\n", - " [ 0.0041, -0.0150, 0.0141, ..., 0.0049, -0.0348, -0.0028],\n", - " [ 0.0176, 0.0132, 0.0090, ..., -0.0117, 0.0241, 0.0417]])),\n", - " ('model.layers.4.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.4.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.5.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.5.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.5.mixer.in_proj.weight',\n", - " tensor([[ 0.0270, 0.0124, 0.0098, ..., 0.0170, -0.0225, 0.0032],\n", - " [ 0.0245, -0.0008, 0.0226, ..., 0.0219, -0.0219, 0.0087],\n", - " [-0.0175, 0.0181, 0.0124, ..., 0.0038, -0.0094, 0.0079],\n", - " ...,\n", - " [-0.0080, -0.0011, 0.0316, ..., -0.0012, 0.0254, 0.0251],\n", - " [-0.0141, -0.0159, -0.0069, ..., 0.0147, -0.0161, -0.0093],\n", - " [ 0.0252, 0.0125, 0.0174, ..., -0.0065, 0.0110, 0.0272]])),\n", - " ('model.layers.5.mixer.conv1d.weight',\n", - " tensor([[[ 0.0684, -0.4353, 0.3899, 0.3199]],\n", - " \n", - " [[ 0.4136, 0.4306, -0.4871, 0.4781]],\n", - " \n", - " [[-0.2516, 0.2109, 0.3891, 0.1501]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.0781, -0.0675, -0.2995, -0.1805]],\n", - " \n", - " [[-0.3360, -0.4148, 0.1846, -0.1013]],\n", - " \n", - " [[ 0.1725, 0.1929, -0.0337, 0.1375]]])),\n", - " ('model.layers.5.mixer.conv1d.bias',\n", - " tensor([-0.4975, -0.0629, -0.2420, ..., -0.2253, 0.2512, 0.2788])),\n", - " ('model.layers.5.mixer.out_proj.weight',\n", - " tensor([[ 1.4306e-02, 1.3230e-02, -2.4141e-02, ..., 1.1763e-02,\n", - " 7.0706e-03, -4.7970e-03],\n", - " [ 2.7478e-02, 1.5179e-03, 1.9229e-02, ..., 1.0928e-02,\n", - " 2.2802e-02, -2.9729e-03],\n", - " [ 1.0169e-02, -1.0741e-02, 2.0628e-02, ..., -1.8109e-02,\n", - " -4.2582e-03, 2.4007e-02],\n", - " ...,\n", - " [-3.2843e-03, 3.7835e-03, -6.7958e-03, ..., -2.6205e-02,\n", - " -2.0391e-02, 5.3912e-03],\n", - " [ 1.2515e-02, -6.4975e-03, 9.9616e-05, ..., 1.0444e-02,\n", - " -2.0596e-02, -8.2915e-03],\n", - " [ 1.7899e-02, 2.0418e-02, -1.9891e-02, ..., -6.6709e-03,\n", - " -3.8566e-02, 2.7005e-02]])),\n", - " ('model.layers.5.mlp.gate_proj.weight',\n", - " tensor([[-2.3807e-03, 2.2714e-03, 2.2736e-05, ..., -2.3039e-03,\n", - " 3.6159e-02, -1.7253e-02],\n", - " [ 3.6929e-02, -6.2031e-03, 1.3606e-02, ..., 2.3592e-02,\n", - " 4.4487e-03, -9.6723e-03],\n", - " [ 4.7507e-02, 2.6413e-02, 1.6759e-02, ..., 1.1910e-02,\n", - " 1.2872e-02, -1.0443e-02],\n", - " ...,\n", - " [-2.0354e-02, -3.9074e-03, 9.7952e-03, ..., 1.0730e-02,\n", - " 2.8752e-02, -8.0048e-03],\n", - " [ 2.5331e-02, -9.9732e-03, 1.0772e-02, ..., 2.0420e-02,\n", - " -3.2179e-02, -1.6437e-02],\n", - " [-3.4425e-02, -1.4578e-02, 2.9686e-03, ..., 4.5907e-02,\n", - " 7.7639e-03, -2.2494e-03]])),\n", - " ('model.layers.5.mlp.up_proj.weight',\n", - " tensor([[ 1.5868e-02, -1.9222e-02, -1.2880e-03, ..., 8.3353e-03,\n", - " -1.8538e-02, 6.7395e-03],\n", - " [-1.8051e-02, -5.0142e-02, -2.2177e-03, ..., -9.3852e-03,\n", - " -3.0374e-02, 2.5795e-02],\n", - " [-1.1737e-02, 2.6278e-02, -2.3205e-02, ..., -1.8399e-03,\n", - " 1.4115e-02, -2.6438e-02],\n", - " ...,\n", - " [ 2.7706e-02, -2.5067e-03, -8.7058e-03, ..., 2.1662e-03,\n", - " -4.9858e-02, -1.1575e-02],\n", - " [-9.5670e-04, 2.1698e-02, -5.4794e-03, ..., -1.0661e-02,\n", - " 1.8568e-02, 5.2615e-03],\n", - " [ 1.0739e-03, 2.2945e-02, 3.0835e-02, ..., 4.1212e-03,\n", - " 1.2643e-02, -1.1568e-05]])),\n", - " ('model.layers.5.mlp.down_proj.weight',\n", - " tensor([[ 0.0052, -0.0343, 0.0072, ..., 0.0004, 0.0320, 0.0362],\n", - " [ 0.0171, -0.0238, -0.0316, ..., 0.0231, 0.0377, 0.0141],\n", - " [-0.0205, 0.0152, 0.0002, ..., -0.0061, -0.0353, -0.0138],\n", - " ...,\n", - " [-0.0039, -0.0039, 0.0326, ..., -0.0208, 0.0160, 0.0185],\n", - " [ 0.0176, -0.0300, -0.0024, ..., -0.0292, -0.0254, -0.0366],\n", - " [ 0.0361, 0.0243, -0.0253, ..., -0.0036, -0.0099, -0.0133]])),\n", - " ('model.layers.5.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.5.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.6.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.6.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.6.mixer.in_proj.weight',\n", - " tensor([[-0.0505, -0.0650, 0.0059, ..., 0.0060, 0.0347, 0.0149],\n", - " [-0.0216, 0.0057, -0.0281, ..., -0.0162, 0.0081, 0.0016],\n", - " [-0.0339, -0.0314, 0.0253, ..., 0.0030, 0.0139, -0.0039],\n", - " ...,\n", - " [ 0.0355, -0.0238, -0.0015, ..., 0.0063, 0.0284, -0.0089],\n", - " [ 0.0093, -0.0381, -0.0261, ..., -0.0170, -0.0170, -0.0288],\n", - " [-0.0228, -0.0110, 0.0107, ..., 0.0300, 0.0010, 0.0141]])),\n", - " ('model.layers.6.mixer.conv1d.weight',\n", - " tensor([[[ 0.4364, 0.2888, 0.2343, 0.3226]],\n", - " \n", - " [[ 0.2804, 0.3558, 0.4061, -0.0480]],\n", - " \n", - " [[ 0.4964, 0.0709, 0.0748, 0.0971]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4291, 0.2445, -0.3121, 0.4013]],\n", - " \n", - " [[-0.1590, -0.1516, 0.0804, 0.2009]],\n", - " \n", - " [[ 0.1686, 0.0492, -0.2932, 0.1381]]])),\n", - " ('model.layers.6.mixer.conv1d.bias',\n", - " tensor([ 0.4241, -0.0500, 0.3393, ..., 0.1598, -0.4924, -0.3241])),\n", - " ('model.layers.6.mixer.out_proj.weight',\n", - " tensor([[ 0.0026, 0.0272, 0.0005, ..., 0.0434, -0.0293, -0.0105],\n", - " [ 0.0323, -0.0515, 0.0107, ..., -0.0406, 0.0252, -0.0038],\n", - " [-0.0156, -0.0078, 0.0173, ..., 0.0312, -0.0014, -0.0014],\n", - " ...,\n", - " [ 0.0014, -0.0522, -0.0154, ..., 0.0090, -0.0050, -0.0049],\n", - " [ 0.0350, 0.0099, -0.0014, ..., -0.0008, -0.0185, -0.0033],\n", - " [ 0.0134, 0.0002, 0.0325, ..., -0.0129, 0.0165, -0.0265]])),\n", - " ('model.layers.6.mlp.gate_proj.weight',\n", - " tensor([[-0.0011, 0.0202, 0.0236, ..., -0.0137, -0.0063, 0.0085],\n", - " [ 0.0163, 0.0261, 0.0120, ..., -0.0003, -0.0254, 0.0001],\n", - " [ 0.0318, -0.0121, 0.0103, ..., -0.0053, 0.0194, 0.0530],\n", - " ...,\n", - " [ 0.0039, 0.0228, -0.0147, ..., 0.0027, 0.0092, -0.0033],\n", - " [-0.0040, 0.0144, 0.0038, ..., -0.0106, -0.0022, 0.0094],\n", - " [ 0.0220, 0.0296, 0.0550, ..., 0.0079, -0.0135, -0.0092]])),\n", - " ('model.layers.6.mlp.up_proj.weight',\n", - " tensor([[ 0.0061, -0.0291, -0.0133, ..., 0.0054, -0.0049, -0.0028],\n", - " [-0.0032, -0.0201, 0.0218, ..., -0.0155, -0.0264, 0.0496],\n", - " [-0.0046, 0.0384, -0.0093, ..., 0.0356, -0.0245, 0.0175],\n", - " ...,\n", - " [-0.0111, -0.0092, -0.0143, ..., 0.0010, -0.0453, 0.0024],\n", - " [ 0.0078, -0.0025, 0.0227, ..., -0.0130, 0.0118, 0.0095],\n", - " [ 0.0234, -0.0114, -0.0102, ..., -0.0179, -0.0066, -0.0115]])),\n", - " ('model.layers.6.mlp.down_proj.weight',\n", - " tensor([[ 3.6976e-02, 1.7124e-02, -2.1290e-02, ..., -2.5206e-02,\n", - " 4.8023e-03, 9.8474e-03],\n", - " [-7.2866e-03, -5.4149e-03, -2.2242e-03, ..., -8.1606e-03,\n", - " -9.5275e-04, -1.8121e-02],\n", - " [-8.3493e-03, 1.2509e-02, 1.0773e-02, ..., 2.7061e-02,\n", - " 2.8131e-03, 5.8219e-03],\n", - " ...,\n", - " [ 8.7099e-03, 3.9196e-02, -3.5129e-03, ..., -2.3595e-02,\n", - " -8.3965e-03, 2.0074e-02],\n", - " [-2.7467e-02, -2.8721e-03, -2.2291e-02, ..., 9.7135e-03,\n", - " 3.4947e-02, -2.2158e-02],\n", - " [ 6.1744e-03, -4.7684e-03, 4.6690e-04, ..., -3.2948e-03,\n", - " 4.0735e-05, 3.3651e-02]])),\n", - " ('model.layers.6.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.6.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.7.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.7.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.7.mixer.in_proj.weight',\n", - " tensor([[-0.0045, -0.0288, 0.0362, ..., -0.0092, -0.0026, 0.0051],\n", - " [ 0.0160, 0.0139, 0.0057, ..., 0.0121, 0.0071, 0.0134],\n", - " [ 0.0062, 0.0181, 0.0161, ..., -0.0284, -0.0014, -0.0171],\n", - " ...,\n", - " [-0.0053, 0.0067, 0.0095, ..., -0.0175, 0.0235, 0.0125],\n", - " [-0.0048, 0.0041, 0.0038, ..., 0.0099, 0.0194, 0.0124],\n", - " [ 0.0131, 0.0073, -0.0284, ..., 0.0138, -0.0218, 0.0019]])),\n", - " ('model.layers.7.mixer.conv1d.weight',\n", - " tensor([[[ 0.2528, -0.0556, -0.3225, 0.1327]],\n", - " \n", - " [[-0.0437, 0.4941, -0.4075, 0.1062]],\n", - " \n", - " [[-0.3428, 0.2675, 0.1871, 0.0260]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.0409, -0.4458, 0.4488, 0.2841]],\n", - " \n", - " [[-0.2370, -0.3965, 0.0656, -0.1339]],\n", - " \n", - " [[ 0.4677, 0.0073, 0.3741, 0.1525]]])),\n", - " ('model.layers.7.mixer.conv1d.bias',\n", - " tensor([-0.1844, -0.1347, 0.0043, ..., -0.3839, -0.2167, -0.4637])),\n", - " ('model.layers.7.mixer.out_proj.weight',\n", - " tensor([[-2.8471e-02, 3.9783e-03, 6.0125e-03, ..., -1.6079e-02,\n", - " 1.4225e-02, 2.8166e-02],\n", - " [ 5.4680e-03, -5.1414e-03, 5.3077e-05, ..., 1.8734e-02,\n", - " 3.7454e-03, 1.7579e-02],\n", - " [-1.2955e-02, 1.4954e-02, 6.4922e-03, ..., -2.6830e-02,\n", - " 1.4766e-02, -1.8002e-02],\n", - " ...,\n", - " [ 1.7150e-02, 4.6781e-02, -1.1136e-02, ..., 4.7242e-03,\n", - " -1.3072e-02, -1.0412e-02],\n", - " [ 5.5498e-03, -3.0803e-02, -2.4880e-02, ..., -4.2644e-03,\n", - " -1.1047e-02, 1.5815e-02],\n", - " [ 1.7242e-02, 2.7994e-02, -4.8186e-04, ..., -2.2003e-02,\n", - " -2.1834e-02, -2.1826e-02]])),\n", - " ('model.layers.7.mlp.gate_proj.weight',\n", - " tensor([[-0.0302, -0.0160, -0.0341, ..., -0.0121, 0.0007, -0.0338],\n", - " [-0.0186, 0.0257, -0.0154, ..., 0.0153, -0.0029, 0.0163],\n", - " [ 0.0170, 0.0223, -0.0185, ..., -0.0020, 0.0061, 0.0174],\n", - " ...,\n", - " [-0.0044, 0.0044, 0.0077, ..., -0.0183, 0.0041, -0.0003],\n", - " [ 0.0168, 0.0149, -0.0221, ..., 0.0112, 0.0357, 0.0042],\n", - " [ 0.0310, -0.0217, 0.0070, ..., -0.0394, -0.0065, 0.0204]])),\n", - " ('model.layers.7.mlp.up_proj.weight',\n", - " tensor([[-0.0031, -0.0110, 0.0091, ..., 0.0152, -0.0013, 0.0096],\n", - " [ 0.0013, 0.0354, -0.0037, ..., 0.0130, 0.0204, 0.0262],\n", - " [-0.0075, -0.0044, 0.0207, ..., 0.0057, 0.0115, 0.0151],\n", - " ...,\n", - " [-0.0015, 0.0095, -0.0100, ..., -0.0150, 0.0105, -0.0350],\n", - " [-0.0300, -0.0092, -0.0176, ..., -0.0113, 0.0164, -0.0117],\n", - " [-0.0291, -0.0085, 0.0058, ..., 0.0386, -0.0174, -0.0092]])),\n", - " ('model.layers.7.mlp.down_proj.weight',\n", - " tensor([[-0.0276, 0.0017, -0.0217, ..., 0.0302, -0.0079, -0.0003],\n", - " [ 0.0379, 0.0052, 0.0052, ..., 0.0145, 0.0139, -0.0143],\n", - " [ 0.0176, -0.0028, 0.0172, ..., -0.0205, -0.0165, -0.0040],\n", - " ...,\n", - " [ 0.0095, -0.0139, 0.0077, ..., -0.0080, 0.0339, 0.0172],\n", - " [-0.0177, 0.0009, -0.0245, ..., 0.0040, 0.0258, 0.0202],\n", - " [-0.0064, -0.0270, 0.0041, ..., -0.0133, -0.0040, 0.0038]])),\n", - " ('model.layers.7.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.7.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.8.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.8.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.8.mixer.in_proj.weight',\n", - " tensor([[ 0.0050, 0.0270, -0.0196, ..., -0.0121, -0.0090, 0.0083],\n", - " [-0.0083, -0.0177, 0.0159, ..., 0.0298, -0.0202, -0.0265],\n", - " [ 0.0058, 0.0186, 0.0125, ..., -0.0067, -0.0255, 0.0298],\n", - " ...,\n", - " [-0.0164, 0.0012, 0.0023, ..., -0.0355, 0.0347, -0.0011],\n", - " [-0.0371, 0.0033, 0.0345, ..., -0.0097, 0.0019, 0.0185],\n", - " [-0.0322, -0.0160, 0.0072, ..., -0.0195, -0.0229, 0.0118]])),\n", - " ('model.layers.8.mixer.conv1d.weight',\n", - " tensor([[[-0.0520, 0.3004, -0.1990, 0.2512]],\n", - " \n", - " [[-0.4120, -0.0055, 0.1484, -0.3316]],\n", - " \n", - " [[ 0.3939, -0.0567, 0.1432, 0.1880]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.2849, 0.2494, -0.2141, -0.3375]],\n", - " \n", - " [[-0.2823, -0.2402, 0.2228, 0.2331]],\n", - " \n", - " [[ 0.1914, 0.4269, 0.1228, -0.3408]]])),\n", - " ('model.layers.8.mixer.conv1d.bias',\n", - " tensor([0.1304, 0.2065, 0.3084, ..., 0.3863, 0.4883, 0.4724])),\n", - " ('model.layers.8.mixer.out_proj.weight',\n", - " tensor([[ 0.0008, -0.0019, 0.0084, ..., -0.0003, 0.0045, 0.0024],\n", - " [ 0.0137, -0.0003, -0.0031, ..., 0.0013, 0.0131, 0.0090],\n", - " [ 0.0095, 0.0488, -0.0355, ..., 0.0344, -0.0229, -0.0150],\n", - " ...,\n", - " [ 0.0029, 0.0164, -0.0380, ..., -0.0005, -0.0031, 0.0127],\n", - " [-0.0039, 0.0283, 0.0295, ..., 0.0271, -0.0105, -0.0158],\n", - " [-0.0057, -0.0178, 0.0129, ..., 0.0323, -0.0091, 0.0178]])),\n", - " ('model.layers.8.mlp.gate_proj.weight',\n", - " tensor([[-0.0047, 0.0037, -0.0129, ..., 0.0255, -0.0118, 0.0084],\n", - " [ 0.0418, -0.0020, 0.0205, ..., 0.0161, 0.0306, 0.0250],\n", - " [ 0.0011, 0.0144, 0.0204, ..., -0.0007, 0.0298, -0.0067],\n", - " ...,\n", - " [-0.0536, -0.0083, -0.0049, ..., -0.0028, 0.0301, -0.0205],\n", - " [ 0.0031, 0.0139, 0.0070, ..., 0.0120, 0.0004, -0.0226],\n", - " [ 0.0114, -0.0173, 0.0212, ..., -0.0413, -0.0069, 0.0007]])),\n", - " ('model.layers.8.mlp.up_proj.weight',\n", - " tensor([[-0.0005, 0.0028, -0.0137, ..., 0.0078, 0.0348, 0.0006],\n", - " [-0.0020, 0.0300, -0.0056, ..., -0.0258, -0.0130, -0.0212],\n", - " [-0.0135, -0.0111, 0.0151, ..., 0.0043, -0.0426, -0.0109],\n", - " ...,\n", - " [ 0.0273, 0.0057, -0.0108, ..., -0.0205, 0.0005, -0.0239],\n", - " [ 0.0226, 0.0325, -0.0187, ..., 0.0069, -0.0132, -0.0002],\n", - " [ 0.0280, -0.0007, -0.0047, ..., 0.0159, -0.0054, -0.0172]])),\n", - " ('model.layers.8.mlp.down_proj.weight',\n", - " tensor([[-0.0091, 0.0072, 0.0030, ..., 0.0025, -0.0159, -0.0277],\n", - " [ 0.0159, -0.0260, -0.0076, ..., -0.0059, -0.0129, 0.0358],\n", - " [ 0.0026, -0.0357, -0.0138, ..., -0.0326, -0.0291, 0.0010],\n", - " ...,\n", - " [-0.0237, 0.0272, -0.0130, ..., -0.0280, 0.0097, -0.0563],\n", - " [ 0.0092, 0.0056, 0.0079, ..., -0.0224, 0.0039, -0.0054],\n", - " [-0.0109, -0.0241, -0.0223, ..., -0.0187, 0.0190, 0.0082]])),\n", - " ('model.layers.8.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.8.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.9.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.9.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.9.mixer.in_proj.weight',\n", - " tensor([[ 4.9824e-02, 5.7576e-03, -5.1022e-03, ..., -2.5615e-02,\n", - " 7.1750e-04, 1.5247e-02],\n", - " [-2.8065e-02, -1.2649e-02, -2.3566e-02, ..., 1.7742e-02,\n", - " -1.1202e-02, -2.1476e-02],\n", - " [ 2.0911e-02, 1.6496e-02, -1.9818e-02, ..., 4.0223e-02,\n", - " 1.8544e-02, -2.3633e-02],\n", - " ...,\n", - " [-4.3387e-02, -1.6504e-02, 2.2008e-02, ..., -2.5138e-03,\n", - " -5.6073e-03, -4.8212e-03],\n", - " [-1.9964e-05, -1.5835e-02, 1.2977e-02, ..., 4.1913e-03,\n", - " 4.5898e-02, -3.5822e-02],\n", - " [ 3.1376e-02, -5.4614e-03, -2.5093e-02, ..., -3.7903e-03,\n", - " 1.3560e-02, 3.3366e-02]])),\n", - " ('model.layers.9.mixer.conv1d.weight',\n", - " tensor([[[ 0.1986, -0.1666, -0.4140, -0.4607]],\n", - " \n", - " [[-0.3454, -0.3973, 0.2169, -0.2138]],\n", - " \n", - " [[ 0.2006, -0.3736, 0.3944, -0.0589]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4604, 0.1224, -0.2571, -0.0286]],\n", - " \n", - " [[-0.2723, -0.1617, 0.3483, 0.2299]],\n", - " \n", - " [[ 0.4866, 0.2559, 0.3969, 0.0554]]])),\n", - " ('model.layers.9.mixer.conv1d.bias',\n", - " tensor([ 0.3388, 0.4633, -0.3762, ..., -0.3491, -0.2971, 0.0494])),\n", - " ('model.layers.9.mixer.out_proj.weight',\n", - " tensor([[ 0.0023, -0.0181, 0.0358, ..., 0.0243, 0.0070, -0.0183],\n", - " [ 0.0006, 0.0065, 0.0057, ..., -0.0351, -0.0107, 0.0132],\n", - " [ 0.0153, -0.0038, 0.0059, ..., -0.0285, -0.0247, -0.0104],\n", - " ...,\n", - " [ 0.0244, -0.0120, 0.0064, ..., -0.0133, 0.0263, 0.0016],\n", - " [ 0.0056, -0.0111, 0.0029, ..., -0.0017, -0.0172, -0.0071],\n", - " [-0.0056, -0.0192, -0.0238, ..., 0.0245, -0.0102, -0.0331]])),\n", - " ('model.layers.9.mlp.gate_proj.weight',\n", - " tensor([[-0.0132, 0.0014, -0.0413, ..., -0.0254, -0.0245, 0.0031],\n", - " [-0.0195, -0.0107, -0.0192, ..., 0.0012, -0.0026, 0.0148],\n", - " [-0.0074, -0.0070, -0.0078, ..., 0.0013, -0.0011, -0.0111],\n", - " ...,\n", - " [-0.0137, 0.0302, 0.0084, ..., -0.0063, -0.0065, 0.0240],\n", - " [ 0.0072, 0.0134, 0.0161, ..., 0.0122, 0.0182, 0.0137],\n", - " [ 0.0079, 0.0008, 0.0160, ..., 0.0281, 0.0226, 0.0058]])),\n", - " ('model.layers.9.mlp.up_proj.weight',\n", - " tensor([[ 0.0078, 0.0153, -0.0155, ..., 0.0153, -0.0164, -0.0140],\n", - " [-0.0072, -0.0050, 0.0030, ..., 0.0146, -0.0148, -0.0080],\n", - " [ 0.0165, -0.0078, 0.0005, ..., -0.0545, -0.0096, 0.0296],\n", - " ...,\n", - " [-0.0253, 0.0183, -0.0081, ..., -0.0061, 0.0270, -0.0003],\n", - " [-0.0015, -0.0320, 0.0361, ..., -0.0087, 0.0341, -0.0157],\n", - " [ 0.0041, 0.0102, -0.0195, ..., -0.0441, -0.0106, 0.0275]])),\n", - " ('model.layers.9.mlp.down_proj.weight',\n", - " tensor([[-6.3367e-02, -1.8214e-02, 5.7221e-03, ..., 2.1307e-02,\n", - " -3.0707e-02, -1.3281e-02],\n", - " [-7.7457e-05, -9.1894e-05, 6.8686e-03, ..., -4.7175e-03,\n", - " -1.1585e-03, -2.7604e-02],\n", - " [ 2.9301e-02, -5.9431e-03, -2.5356e-03, ..., -2.7858e-02,\n", - " 1.1647e-02, 1.1245e-02],\n", - " ...,\n", - " [-1.0442e-02, -9.6151e-03, -3.6635e-02, ..., -1.1052e-02,\n", - " -4.5122e-03, 4.0012e-03],\n", - " [ 3.2950e-02, -1.3836e-03, -7.8318e-03, ..., -1.2788e-03,\n", - " 2.3422e-02, -3.2098e-02],\n", - " [-9.2294e-03, 1.3838e-02, -2.0327e-02, ..., -3.8760e-02,\n", - " 2.2118e-02, 1.0696e-02]])),\n", - " ('model.layers.9.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.9.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.10.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.10.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.10.mixer.in_proj.weight',\n", - " tensor([[ 0.0096, -0.0159, 0.0141, ..., 0.0111, 0.0218, 0.0220],\n", - " [-0.0381, -0.0015, 0.0126, ..., -0.0066, -0.0034, -0.0119],\n", - " [ 0.0223, 0.0032, -0.0195, ..., -0.0107, -0.0018, 0.0059],\n", - " ...,\n", - " [-0.0256, -0.0170, -0.0362, ..., -0.0007, -0.0039, 0.0075],\n", - " [ 0.0136, -0.0045, 0.0128, ..., -0.0017, 0.0083, -0.0004],\n", - " [-0.0246, -0.0021, 0.0073, ..., 0.0020, 0.0071, 0.0090]])),\n", - " ('model.layers.10.mixer.conv1d.weight',\n", - " tensor([[[ 0.0463, -0.4497, -0.0679, -0.2209]],\n", - " \n", - " [[-0.3805, 0.4459, 0.1999, -0.4996]],\n", - " \n", - " [[ 0.1529, 0.1789, -0.1535, 0.1824]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.1087, -0.4478, -0.0420, 0.3437]],\n", - " \n", - " [[-0.2809, -0.4617, 0.3209, 0.4873]],\n", - " \n", - " [[ 0.1139, -0.0060, -0.0219, 0.0853]]])),\n", - " ('model.layers.10.mixer.conv1d.bias',\n", - " tensor([ 0.1364, -0.0475, 0.0849, ..., 0.1928, 0.2075, 0.1058])),\n", - " ('model.layers.10.mixer.out_proj.weight',\n", - " tensor([[-0.0164, -0.0188, 0.0174, ..., -0.0106, -0.0107, -0.0036],\n", - " [ 0.0048, -0.0016, -0.0444, ..., -0.0182, -0.0264, -0.0038],\n", - " [ 0.0089, -0.0225, -0.0002, ..., -0.0141, -0.0008, -0.0037],\n", - " ...,\n", - " [-0.0005, 0.0159, 0.0033, ..., 0.0187, -0.0064, 0.0233],\n", - " [-0.0050, 0.0296, 0.0147, ..., -0.0018, 0.0137, -0.0346],\n", - " [-0.0064, -0.0132, -0.0434, ..., -0.0173, -0.0113, -0.0175]])),\n", - " ('model.layers.10.mlp.gate_proj.weight',\n", - " tensor([[-0.0174, -0.0053, -0.0325, ..., -0.0072, -0.0280, 0.0033],\n", - " [ 0.0006, -0.0160, 0.0346, ..., 0.0019, 0.0059, 0.0198],\n", - " [ 0.0231, -0.0187, 0.0115, ..., 0.0085, 0.0080, 0.0061],\n", - " ...,\n", - " [ 0.0153, 0.0241, -0.0184, ..., 0.0089, -0.0242, 0.0010],\n", - " [-0.0019, -0.0322, 0.0011, ..., -0.0097, -0.0305, 0.0065],\n", - " [-0.0107, 0.0240, 0.0168, ..., 0.0226, -0.0238, 0.0117]])),\n", - " ('model.layers.10.mlp.up_proj.weight',\n", - " tensor([[-0.0072, 0.0352, 0.0282, ..., -0.0025, -0.0114, 0.0129],\n", - " [-0.0102, 0.0196, 0.0760, ..., 0.0461, -0.0058, -0.0112],\n", - " [-0.0271, 0.0323, -0.0069, ..., 0.0133, -0.0371, -0.0619],\n", - " ...,\n", - " [ 0.0100, 0.0011, 0.0262, ..., -0.0232, 0.0217, 0.0002],\n", - " [ 0.0151, -0.0266, -0.0074, ..., 0.0096, 0.0036, 0.0033],\n", - " [ 0.0004, 0.0103, 0.0363, ..., -0.0095, -0.0309, -0.0059]])),\n", - " ('model.layers.10.mlp.down_proj.weight',\n", - " tensor([[ 0.0124, -0.0225, -0.0294, ..., 0.0280, 0.0056, 0.0231],\n", - " [ 0.0124, -0.0030, 0.0014, ..., 0.0323, 0.0094, -0.0034],\n", - " [-0.0078, 0.0041, -0.0056, ..., 0.0241, -0.0278, -0.0152],\n", - " ...,\n", - " [-0.0044, 0.0025, -0.0161, ..., -0.0075, -0.0126, 0.0014],\n", - " [-0.0109, -0.0050, 0.0327, ..., -0.0300, -0.0048, 0.0284],\n", - " [ 0.0050, -0.0183, 0.0086, ..., -0.0072, 0.0139, -0.0010]])),\n", - " ('model.layers.10.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.10.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.11.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.11.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.11.mixer.in_proj.weight',\n", - " tensor([[-0.0133, 0.0225, 0.0486, ..., -0.0214, -0.0120, -0.0150],\n", - " [ 0.0183, 0.0020, 0.0079, ..., -0.0163, 0.0016, -0.0214],\n", - " [-0.0276, -0.0112, 0.0121, ..., -0.0057, -0.0143, -0.0462],\n", - " ...,\n", - " [-0.0142, -0.0080, -0.0194, ..., 0.0087, -0.0212, -0.0140],\n", - " [ 0.0060, -0.0005, -0.0171, ..., -0.0017, 0.0223, 0.0169],\n", - " [-0.0290, -0.0016, 0.0117, ..., 0.0037, 0.0047, 0.0152]])),\n", - " ('model.layers.11.mixer.conv1d.weight',\n", - " tensor([[[-0.2822, -0.4216, 0.4786, 0.0802]],\n", - " \n", - " [[-0.3671, 0.1761, -0.2686, 0.1631]],\n", - " \n", - " [[-0.3902, -0.2811, -0.0748, 0.4662]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.1623, 0.2871, -0.4585, 0.4755]],\n", - " \n", - " [[-0.0260, 0.4541, -0.2983, 0.2297]],\n", - " \n", - " [[-0.2991, -0.3590, -0.3256, -0.1434]]])),\n", - " ('model.layers.11.mixer.conv1d.bias',\n", - " tensor([ 0.1218, -0.0542, 0.3485, ..., 0.0528, 0.2711, -0.2811])),\n", - " ('model.layers.11.mixer.out_proj.weight',\n", - " tensor([[ 0.0032, 0.0028, -0.0122, ..., -0.0299, -0.0105, 0.0021],\n", - " [-0.0466, -0.0170, -0.0017, ..., 0.0156, -0.0287, 0.0066],\n", - " [ 0.0016, 0.0054, -0.0071, ..., -0.0240, 0.0215, -0.0046],\n", - " ...,\n", - " [-0.0210, 0.0034, -0.0267, ..., 0.0461, -0.0076, -0.0016],\n", - " [-0.0012, -0.0101, 0.0196, ..., 0.0121, -0.0043, -0.0143],\n", - " [-0.0067, 0.0086, 0.0134, ..., 0.0080, 0.0255, 0.0225]])),\n", - " ('model.layers.11.mlp.gate_proj.weight',\n", - " tensor([[ 0.0179, -0.0429, -0.0134, ..., 0.0110, 0.0368, -0.0259],\n", - " [ 0.0013, -0.0231, 0.0072, ..., -0.0056, -0.0012, -0.0037],\n", - " [-0.0172, -0.0162, 0.0088, ..., -0.0175, 0.0079, -0.0065],\n", - " ...,\n", - " [ 0.0287, -0.0289, 0.0045, ..., 0.0039, 0.0269, 0.0199],\n", - " [ 0.0043, -0.0202, -0.0261, ..., 0.0104, -0.0161, -0.0057],\n", - " [-0.0154, 0.0085, 0.0061, ..., 0.0208, 0.0001, 0.0166]])),\n", - " ('model.layers.11.mlp.up_proj.weight',\n", - " tensor([[-0.0107, 0.0328, 0.0065, ..., -0.0190, -0.0082, -0.0047],\n", - " [-0.0001, 0.0102, 0.0310, ..., -0.0396, -0.0278, -0.0095],\n", - " [-0.0288, 0.0052, 0.0137, ..., -0.0220, 0.0007, -0.0170],\n", - " ...,\n", - " [ 0.0213, -0.0074, -0.0033, ..., 0.0183, 0.0336, -0.0180],\n", - " [-0.0098, -0.0162, 0.0486, ..., 0.0191, 0.0064, 0.0269],\n", - " [-0.0251, 0.0081, 0.0053, ..., 0.0110, 0.0023, 0.0041]])),\n", - " ('model.layers.11.mlp.down_proj.weight',\n", - " tensor([[ 0.0166, -0.0410, 0.0066, ..., -0.0273, 0.0220, 0.0184],\n", - " [ 0.0092, 0.0087, -0.0136, ..., 0.0013, -0.0205, 0.0247],\n", - " [-0.0252, -0.0040, -0.0112, ..., -0.0331, 0.0201, -0.0038],\n", - " ...,\n", - " [ 0.0072, 0.0190, 0.0089, ..., 0.0098, -0.0235, -0.0141],\n", - " [-0.0045, -0.0381, -0.0134, ..., 0.0171, -0.0077, -0.0180],\n", - " [ 0.0109, 0.0060, 0.0048, ..., -0.0108, -0.0122, 0.0110]])),\n", - " ('model.layers.11.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.11.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.12.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.12.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.12.mixer.in_proj.weight',\n", - " tensor([[ 0.0043, 0.0138, 0.0138, ..., -0.0042, 0.0121, -0.0190],\n", - " [ 0.0002, -0.0199, 0.0315, ..., 0.0170, 0.0051, -0.0062],\n", - " [-0.0053, 0.0043, 0.0283, ..., -0.0087, 0.0069, -0.0160],\n", - " ...,\n", - " [-0.0313, 0.0200, 0.0036, ..., 0.0147, 0.0153, 0.0098],\n", - " [-0.0157, 0.0120, -0.0112, ..., 0.0166, -0.0005, 0.0066],\n", - " [-0.0271, 0.0037, 0.0163, ..., 0.0304, 0.0023, 0.0083]])),\n", - " ('model.layers.12.mixer.conv1d.weight',\n", - " tensor([[[-0.4295, -0.2474, -0.2324, -0.2138]],\n", - " \n", - " [[ 0.3607, -0.4824, 0.1667, 0.1348]],\n", - " \n", - " [[ 0.3596, 0.1167, 0.1089, -0.4010]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3527, -0.3346, -0.3755, 0.1450]],\n", - " \n", - " [[-0.1921, -0.0632, -0.4885, -0.3986]],\n", - " \n", - " [[ 0.1950, 0.3037, -0.1630, 0.0353]]])),\n", - " ('model.layers.12.mixer.conv1d.bias',\n", - " tensor([0.3103, 0.0451, 0.4533, ..., 0.0235, 0.1819, 0.3933])),\n", - " ('model.layers.12.mixer.out_proj.weight',\n", - " tensor([[ 0.0167, -0.0197, -0.0054, ..., 0.0096, 0.0271, -0.0118],\n", - " [ 0.0167, -0.0455, 0.0001, ..., 0.0003, 0.0265, 0.0111],\n", - " [ 0.0231, -0.0113, 0.0195, ..., -0.0171, -0.0044, -0.0244],\n", - " ...,\n", - " [ 0.0042, 0.0048, 0.0357, ..., 0.0126, -0.0288, 0.0149],\n", - " [ 0.0192, 0.0078, 0.0126, ..., 0.0029, 0.0255, -0.0203],\n", - " [-0.0054, -0.0543, 0.0039, ..., -0.0240, 0.0282, 0.0082]])),\n", - " ('model.layers.12.mlp.gate_proj.weight',\n", - " tensor([[-0.0417, -0.0193, -0.0022, ..., 0.0031, 0.0337, 0.0175],\n", - " [ 0.0215, -0.0109, -0.0657, ..., -0.0145, -0.0475, -0.0091],\n", - " [-0.0225, -0.0012, -0.0020, ..., -0.0291, 0.0097, 0.0163],\n", - " ...,\n", - " [-0.0018, 0.0048, -0.0265, ..., -0.0056, 0.0446, 0.0045],\n", - " [ 0.0270, 0.0086, -0.0110, ..., -0.0038, 0.0176, 0.0138],\n", - " [-0.0134, 0.0046, -0.0186, ..., -0.0098, 0.0191, 0.0095]])),\n", - " ('model.layers.12.mlp.up_proj.weight',\n", - " tensor([[ 0.0180, 0.0075, 0.0147, ..., 0.0142, 0.0291, -0.0303],\n", - " [-0.0079, -0.0277, -0.0151, ..., -0.0069, -0.0045, -0.0223],\n", - " [ 0.0180, -0.0087, 0.0074, ..., 0.0215, 0.0274, -0.0199],\n", - " ...,\n", - " [-0.0215, -0.0115, 0.0140, ..., -0.0283, -0.0171, -0.0229],\n", - " [ 0.0231, -0.0179, -0.0386, ..., 0.0364, 0.0311, 0.0048],\n", - " [-0.0111, 0.0079, 0.0328, ..., 0.0285, 0.0423, 0.0039]])),\n", - " ('model.layers.12.mlp.down_proj.weight',\n", - " tensor([[-0.0361, 0.0192, -0.0005, ..., -0.0151, 0.0116, -0.0068],\n", - " [ 0.0203, -0.0064, 0.0061, ..., 0.0325, -0.0004, -0.0299],\n", - " [-0.0028, 0.0131, 0.0141, ..., -0.0108, -0.0070, -0.0090],\n", - " ...,\n", - " [ 0.0165, -0.0198, -0.0242, ..., 0.0162, 0.0099, 0.0025],\n", - " [ 0.0148, 0.0056, -0.0139, ..., 0.0108, -0.0477, 0.0225],\n", - " [ 0.0156, 0.0249, -0.0287, ..., -0.0200, -0.0496, 0.0169]])),\n", - " ('model.layers.12.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.12.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.13.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.13.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.13.mixer.in_proj.weight',\n", - " tensor([[-0.0064, -0.0200, 0.0384, ..., -0.0036, 0.0158, -0.0007],\n", - " [-0.0074, 0.0105, 0.0043, ..., 0.0097, 0.0259, -0.0012],\n", - " [ 0.0297, -0.0146, -0.0012, ..., 0.0273, 0.0309, 0.0087],\n", - " ...,\n", - " [ 0.0204, -0.0063, 0.0136, ..., -0.0092, 0.0196, 0.0057],\n", - " [ 0.0195, 0.0059, 0.0228, ..., 0.0093, -0.0183, -0.0003],\n", - " [-0.0131, -0.0447, -0.0262, ..., -0.0125, 0.0237, -0.0404]])),\n", - " ('model.layers.13.mixer.conv1d.weight',\n", - " tensor([[[ 7.7458e-03, 4.9829e-01, 2.1690e-01, -2.3587e-01]],\n", - " \n", - " [[ 3.7281e-01, -4.0991e-03, 2.4588e-01, -1.1600e-01]],\n", - " \n", - " [[-4.8238e-01, -2.8961e-01, -4.4331e-02, 1.0011e-01]],\n", - " \n", - " ...,\n", - " \n", - " [[-3.6304e-01, -1.4106e-01, -3.5434e-01, 1.4923e-01]],\n", - " \n", - " [[-2.3703e-01, 3.9285e-04, -2.1456e-02, -2.5568e-01]],\n", - " \n", - " [[ 1.5303e-02, -8.3474e-03, -3.2668e-01, -4.8096e-01]]])),\n", - " ('model.layers.13.mixer.conv1d.bias',\n", - " tensor([-0.2462, 0.1532, -0.2298, ..., -0.3016, 0.1210, -0.3777])),\n", - " ('model.layers.13.mixer.out_proj.weight',\n", - " tensor([[-0.0019, 0.0103, 0.0098, ..., -0.0050, 0.0180, -0.0117],\n", - " [-0.0153, 0.0134, -0.0102, ..., 0.0327, -0.0387, 0.0025],\n", - " [ 0.0102, -0.0038, 0.0224, ..., -0.0118, 0.0234, 0.0014],\n", - " ...,\n", - " [-0.0201, 0.0233, 0.0189, ..., 0.0010, 0.0313, 0.0130],\n", - " [ 0.0193, 0.0035, -0.0253, ..., 0.0084, -0.0208, 0.0372],\n", - " [ 0.0367, -0.0029, -0.0205, ..., -0.0055, -0.0209, 0.0082]])),\n", - " ('model.layers.13.mlp.gate_proj.weight',\n", - " tensor([[ 0.0148, -0.0052, 0.0371, ..., -0.0118, 0.0397, -0.0234],\n", - " [ 0.0237, -0.0323, 0.0219, ..., 0.0098, -0.0304, 0.0165],\n", - " [ 0.0168, -0.0289, 0.0038, ..., 0.0022, 0.0174, 0.0043],\n", - " ...,\n", - " [-0.0135, 0.0258, -0.0172, ..., 0.0251, -0.0071, -0.0384],\n", - " [ 0.0005, -0.0123, 0.0116, ..., 0.0041, -0.0108, -0.0068],\n", - " [ 0.0116, 0.0069, 0.0063, ..., 0.0045, -0.0145, 0.0185]])),\n", - " ('model.layers.13.mlp.up_proj.weight',\n", - " tensor([[-0.0002, -0.0120, 0.0069, ..., 0.0005, -0.0108, -0.0284],\n", - " [ 0.0215, 0.0045, 0.0167, ..., 0.0177, -0.0030, 0.0051],\n", - " [ 0.0265, 0.0169, 0.0047, ..., 0.0069, -0.0299, 0.0196],\n", - " ...,\n", - " [ 0.0127, -0.0063, 0.0242, ..., -0.0061, -0.0263, 0.0041],\n", - " [ 0.0142, -0.0515, -0.0221, ..., -0.0369, -0.0399, -0.0210],\n", - " [ 0.0123, 0.0133, -0.0269, ..., 0.0092, -0.0177, 0.0226]])),\n", - " ('model.layers.13.mlp.down_proj.weight',\n", - " tensor([[ 0.0048, 0.0360, -0.0037, ..., 0.0169, 0.0304, -0.0162],\n", - " [ 0.0271, -0.0121, 0.0108, ..., -0.0424, 0.0293, -0.0137],\n", - " [ 0.0225, -0.0061, -0.0096, ..., 0.0075, -0.0168, 0.0142],\n", - " ...,\n", - " [ 0.0039, -0.0152, -0.0156, ..., 0.0181, 0.0105, 0.0070],\n", - " [ 0.0311, 0.0205, 0.0259, ..., -0.0025, 0.0060, -0.0125],\n", - " [ 0.0004, -0.0114, 0.0022, ..., -0.0159, -0.0290, 0.0036]])),\n", - " ('model.layers.13.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.13.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.14.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.14.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.14.mixer.in_proj.weight',\n", - " tensor([[-0.0123, 0.0054, 0.0059, ..., 0.0285, -0.0292, -0.0184],\n", - " [-0.0146, -0.0175, 0.0155, ..., -0.0206, -0.0190, -0.0172],\n", - " [ 0.0050, -0.0235, -0.0159, ..., -0.0013, -0.0102, 0.0082],\n", - " ...,\n", - " [-0.0243, -0.0013, 0.0312, ..., -0.0141, -0.0156, 0.0279],\n", - " [ 0.0018, 0.0181, -0.0188, ..., 0.0593, -0.0155, 0.0156],\n", - " [ 0.0036, 0.0182, -0.0308, ..., 0.0306, -0.0035, 0.0037]])),\n", - " ('model.layers.14.mixer.conv1d.weight',\n", - " tensor([[[-0.4608, 0.4926, -0.2625, 0.3060]],\n", - " \n", - " [[-0.0932, 0.0153, 0.2298, -0.1735]],\n", - " \n", - " [[-0.1927, 0.1979, -0.1773, 0.3277]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.0538, -0.2180, -0.4857, -0.1428]],\n", - " \n", - " [[-0.1736, 0.2405, 0.3148, -0.4481]],\n", - " \n", - " [[-0.4971, -0.1558, 0.2762, -0.1849]]])),\n", - " ('model.layers.14.mixer.conv1d.bias',\n", - " tensor([-0.2181, -0.2375, 0.0896, ..., 0.0744, 0.0857, 0.4347])),\n", - " ('model.layers.14.mixer.out_proj.weight',\n", - " tensor([[-3.8364e-04, 2.4458e-02, 5.8783e-03, ..., -1.3479e-02,\n", - " -2.4306e-02, 5.7698e-03],\n", - " [ 4.5843e-02, -3.9217e-03, -6.9897e-03, ..., 5.5401e-03,\n", - " -1.4523e-02, 1.2266e-02],\n", - " [-7.1069e-03, 5.5550e-03, 1.1359e-02, ..., 3.5839e-02,\n", - " 1.0787e-02, 8.4053e-03],\n", - " ...,\n", - " [ 3.3029e-03, 5.4333e-03, -9.3382e-03, ..., -1.7376e-02,\n", - " 1.5601e-02, -6.3227e-03],\n", - " [-6.9199e-03, -1.6950e-02, 1.5155e-03, ..., 1.2324e-02,\n", - " 1.2259e-02, 5.5500e-02],\n", - " [-1.6177e-02, -6.5257e-05, -9.3656e-03, ..., 1.0653e-02,\n", - " 1.8864e-02, -1.2508e-02]])),\n", - " ('model.layers.14.mlp.gate_proj.weight',\n", - " tensor([[ 0.0279, 0.0025, 0.0214, ..., -0.0137, -0.0042, 0.0172],\n", - " [-0.0240, -0.0150, 0.0170, ..., 0.0090, 0.0002, 0.0172],\n", - " [-0.0181, 0.0052, -0.0418, ..., 0.0106, 0.0052, -0.0264],\n", - " ...,\n", - " [-0.0295, 0.0323, 0.0387, ..., -0.0116, -0.0140, -0.0053],\n", - " [ 0.0411, 0.0189, 0.0236, ..., 0.0094, -0.0176, -0.0066],\n", - " [ 0.0004, 0.0291, 0.0402, ..., 0.0127, -0.0009, 0.0010]])),\n", - " ('model.layers.14.mlp.up_proj.weight',\n", - " tensor([[ 0.0198, -0.0115, -0.0045, ..., 0.0273, 0.0012, -0.0082],\n", - " [-0.0217, 0.0075, 0.0006, ..., 0.0047, -0.0416, -0.0011],\n", - " [ 0.0012, -0.0214, -0.0211, ..., 0.0030, -0.0176, -0.0215],\n", - " ...,\n", - " [ 0.0062, -0.0305, 0.0310, ..., 0.0044, -0.0379, 0.0155],\n", - " [-0.0062, 0.0451, 0.0167, ..., 0.0062, -0.0033, 0.0012],\n", - " [ 0.0293, -0.0186, 0.0295, ..., 0.0092, 0.0100, 0.0038]])),\n", - " ('model.layers.14.mlp.down_proj.weight',\n", - " tensor([[ 0.0019, 0.0114, -0.0202, ..., 0.0227, -0.0227, -0.0005],\n", - " [-0.0437, -0.0045, -0.0385, ..., -0.0083, -0.0135, 0.0172],\n", - " [-0.0032, -0.0024, 0.0137, ..., 0.0071, 0.0034, 0.0104],\n", - " ...,\n", - " [ 0.0210, -0.0237, -0.0166, ..., -0.0105, 0.0490, 0.0155],\n", - " [-0.0109, 0.0112, 0.0082, ..., -0.0342, -0.0133, -0.0086],\n", - " [ 0.0282, -0.0210, -0.0127, ..., -0.0047, -0.0126, 0.0103]])),\n", - " ('model.layers.14.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.14.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.15.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.15.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.15.mixer.in_proj.weight',\n", - " tensor([[-0.0098, -0.0201, -0.0033, ..., -0.0289, 0.0275, 0.0186],\n", - " [ 0.0048, 0.0075, -0.0033, ..., 0.0011, 0.0042, 0.0040],\n", - " [-0.0079, -0.0025, 0.0018, ..., -0.0051, -0.0231, -0.0022],\n", - " ...,\n", - " [ 0.0186, -0.0104, -0.0062, ..., 0.0086, -0.0007, -0.0653],\n", - " [-0.0212, 0.0034, 0.0019, ..., 0.0167, 0.0050, 0.0120],\n", - " [ 0.0066, 0.0381, -0.0225, ..., -0.0043, 0.0229, -0.0004]])),\n", - " ('model.layers.15.mixer.conv1d.weight',\n", - " tensor([[[ 0.2306, 0.2721, 0.3406, 0.4513]],\n", - " \n", - " [[ 0.0991, 0.4973, 0.0010, -0.1445]],\n", - " \n", - " [[ 0.2975, 0.4813, 0.2817, -0.0468]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0104, -0.1473, 0.1685, -0.4390]],\n", - " \n", - " [[ 0.3669, 0.3461, 0.0845, 0.3576]],\n", - " \n", - " [[-0.1177, 0.0524, 0.4329, 0.0687]]])),\n", - " ('model.layers.15.mixer.conv1d.bias',\n", - " tensor([-0.0356, 0.4173, 0.3287, ..., -0.0141, 0.1365, 0.2086])),\n", - " ('model.layers.15.mixer.out_proj.weight',\n", - " tensor([[-0.0137, -0.0239, -0.0133, ..., -0.0177, -0.0125, -0.0015],\n", - " [ 0.0168, 0.0120, 0.0034, ..., 0.0098, 0.0098, 0.0110],\n", - " [-0.0315, 0.0447, 0.0189, ..., 0.0305, 0.0131, -0.0230],\n", - " ...,\n", - " [-0.0480, 0.0170, 0.0025, ..., 0.0317, -0.0378, -0.0236],\n", - " [-0.0319, -0.0290, 0.0023, ..., -0.0093, 0.0354, 0.0126],\n", - " [-0.0107, 0.0100, -0.0101, ..., 0.0046, 0.0205, -0.0203]])),\n", - " ('model.layers.15.mlp.gate_proj.weight',\n", - " tensor([[ 0.0160, 0.0432, 0.0073, ..., -0.0003, -0.0170, 0.0236],\n", - " [ 0.0055, 0.0066, -0.0311, ..., 0.0049, -0.0130, 0.0040],\n", - " [-0.0147, -0.0184, 0.0281, ..., 0.0016, 0.0077, -0.0072],\n", - " ...,\n", - " [-0.0049, -0.0434, -0.0118, ..., 0.0137, -0.0225, -0.0058],\n", - " [ 0.0221, -0.0077, 0.0029, ..., 0.0087, -0.0361, -0.0100],\n", - " [ 0.0263, 0.0228, 0.0050, ..., -0.0557, 0.0037, 0.0196]])),\n", - " ('model.layers.15.mlp.up_proj.weight',\n", - " tensor([[ 0.0093, -0.0189, 0.0173, ..., 0.0276, 0.0075, -0.0215],\n", - " [-0.0147, 0.0241, 0.0109, ..., 0.0120, 0.0032, 0.0327],\n", - " [ 0.0036, 0.0127, 0.0116, ..., 0.0100, -0.0003, 0.0233],\n", - " ...,\n", - " [-0.0063, 0.0160, 0.0138, ..., -0.0078, -0.0098, 0.0150],\n", - " [ 0.0138, -0.0236, 0.0109, ..., -0.0156, -0.0143, 0.0273],\n", - " [ 0.0345, 0.0201, -0.0119, ..., -0.0182, 0.0053, 0.0105]])),\n", - " ('model.layers.15.mlp.down_proj.weight',\n", - " tensor([[-0.0114, 0.0138, -0.0110, ..., 0.0084, -0.0144, 0.0100],\n", - " [ 0.0016, -0.0069, 0.0172, ..., -0.0394, 0.0368, 0.0468],\n", - " [-0.0184, -0.0094, -0.0273, ..., -0.0195, 0.0148, 0.0142],\n", - " ...,\n", - " [ 0.0311, 0.0093, -0.0130, ..., -0.0023, 0.0395, -0.0375],\n", - " [ 0.0056, 0.0027, 0.0061, ..., 0.0058, 0.0225, -0.0153],\n", - " [-0.0031, -0.0107, 0.0020, ..., -0.0173, -0.0050, 0.0423]])),\n", - " ('model.layers.15.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.15.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.16.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.16.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.16.mixer.in_proj.weight',\n", - " tensor([[-0.0063, 0.0006, 0.0130, ..., 0.0186, 0.0408, 0.0126],\n", - " [-0.0015, -0.0029, 0.0268, ..., -0.0042, -0.0209, -0.0046],\n", - " [-0.0034, -0.0286, 0.0185, ..., -0.0125, 0.0050, 0.0033],\n", - " ...,\n", - " [ 0.0045, 0.0133, 0.0220, ..., 0.0165, 0.0287, 0.0371],\n", - " [ 0.0100, -0.0232, 0.0103, ..., -0.0083, -0.0105, -0.0187],\n", - " [-0.0412, -0.0035, 0.0028, ..., 0.0286, 0.0349, -0.0037]])),\n", - " ('model.layers.16.mixer.conv1d.weight',\n", - " tensor([[[-0.1874, 0.2517, 0.0537, 0.1258]],\n", - " \n", - " [[ 0.1465, 0.2013, 0.3547, 0.2689]],\n", - " \n", - " [[ 0.4834, 0.4906, 0.0844, -0.0541]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3004, 0.3313, 0.1688, 0.4381]],\n", - " \n", - " [[-0.0606, 0.3455, -0.0910, 0.1148]],\n", - " \n", - " [[-0.1421, -0.1254, -0.2353, -0.1675]]])),\n", - " ('model.layers.16.mixer.conv1d.bias',\n", - " tensor([ 0.2835, 0.2361, 0.1225, ..., -0.2119, -0.1929, 0.3877])),\n", - " ('model.layers.16.mixer.out_proj.weight',\n", - " tensor([[-0.0121, 0.0194, 0.0060, ..., -0.0029, -0.0147, -0.0085],\n", - " [-0.0216, -0.0012, 0.0287, ..., 0.0102, -0.0133, -0.0153],\n", - " [ 0.0136, -0.0296, 0.0417, ..., -0.0118, -0.0283, 0.0359],\n", - " ...,\n", - " [-0.0263, -0.0003, 0.0022, ..., 0.0135, -0.0519, -0.0254],\n", - " [ 0.0121, -0.0144, -0.0026, ..., 0.0096, 0.0130, 0.0095],\n", - " [-0.0147, -0.0217, 0.0099, ..., 0.0267, -0.0072, -0.0213]])),\n", - " ('model.layers.16.mlp.gate_proj.weight',\n", - " tensor([[ 0.0103, -0.0396, -0.0127, ..., 0.0020, -0.0055, 0.0291],\n", - " [ 0.0194, 0.0357, -0.0020, ..., -0.0112, 0.0448, -0.0224],\n", - " [-0.0390, 0.0142, -0.0224, ..., -0.0030, 0.0102, 0.0078],\n", - " ...,\n", - " [ 0.0165, -0.0251, 0.0196, ..., 0.0213, 0.0040, -0.0228],\n", - " [-0.0145, 0.0218, -0.0032, ..., -0.0240, -0.0079, 0.0256],\n", - " [ 0.0539, -0.0027, -0.0227, ..., -0.0184, -0.0109, 0.0236]])),\n", - " ('model.layers.16.mlp.up_proj.weight',\n", - " tensor([[ 7.1125e-03, -3.2583e-04, -2.6297e-02, ..., -4.9575e-03,\n", - " -1.2243e-02, -1.3005e-02],\n", - " [ 2.5637e-02, -1.1874e-02, 1.1376e-02, ..., -1.4700e-02,\n", - " -1.5193e-02, 2.6111e-03],\n", - " [-4.8919e-02, -4.9716e-04, 5.8527e-03, ..., 8.6775e-05,\n", - " 1.0694e-02, 3.7682e-03],\n", - " ...,\n", - " [ 8.8393e-03, -4.3317e-02, 2.8372e-02, ..., 2.2709e-02,\n", - " -4.8128e-03, 1.6899e-02],\n", - " [ 1.3257e-02, 2.1000e-02, 1.5035e-03, ..., 1.5603e-02,\n", - " -5.5857e-03, 4.0449e-03],\n", - " [-2.6754e-02, -1.6263e-02, 1.9013e-02, ..., -9.0918e-03,\n", - " -8.0242e-03, -1.0925e-02]])),\n", - " ('model.layers.16.mlp.down_proj.weight',\n", - " tensor([[ 0.0207, -0.0038, -0.0234, ..., 0.0299, -0.0329, -0.0117],\n", - " [-0.0316, 0.0032, 0.0131, ..., 0.0020, -0.0320, 0.0381],\n", - " [-0.0192, -0.0031, -0.0030, ..., -0.0224, 0.0037, 0.0085],\n", - " ...,\n", - " [ 0.0044, 0.0281, -0.0208, ..., 0.0179, -0.0085, -0.0010],\n", - " [-0.0076, -0.0008, 0.0483, ..., 0.0082, -0.0177, -0.0039],\n", - " [ 0.0224, 0.0019, 0.0181, ..., 0.0143, -0.0252, 0.0022]])),\n", - " ('model.layers.16.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.16.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.17.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.17.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.17.mixer.in_proj.weight',\n", - " tensor([[-0.0115, 0.0061, -0.0062, ..., -0.0132, -0.0047, 0.0274],\n", - " [ 0.0076, 0.0278, -0.0147, ..., 0.0439, -0.0093, -0.0154],\n", - " [-0.0383, -0.0264, -0.0053, ..., -0.0206, 0.0275, 0.0188],\n", - " ...,\n", - " [ 0.0096, 0.0228, 0.0351, ..., 0.0227, 0.0138, -0.0164],\n", - " [ 0.0321, -0.0293, -0.0054, ..., 0.0109, -0.0113, -0.0130],\n", - " [-0.0120, -0.0132, 0.0092, ..., -0.0338, 0.0308, -0.0135]])),\n", - " ('model.layers.17.mixer.conv1d.weight',\n", - " tensor([[[-0.4933, 0.4156, 0.2523, -0.0026]],\n", - " \n", - " [[-0.2572, 0.4916, 0.3642, -0.2145]],\n", - " \n", - " [[ 0.0261, 0.4852, -0.1448, 0.2288]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3698, -0.4122, -0.2264, -0.1378]],\n", - " \n", - " [[ 0.1447, 0.4556, -0.0466, 0.0389]],\n", - " \n", - " [[-0.3891, 0.4149, 0.1454, -0.4282]]])),\n", - " ('model.layers.17.mixer.conv1d.bias',\n", - " tensor([-0.3919, -0.4015, 0.2591, ..., -0.3368, 0.2285, 0.1701])),\n", - " ('model.layers.17.mixer.out_proj.weight',\n", - " tensor([[-0.0127, -0.0155, 0.0193, ..., 0.0204, 0.0025, 0.0159],\n", - " [ 0.0192, 0.0194, -0.0169, ..., -0.0062, 0.0262, 0.0070],\n", - " [ 0.0397, 0.0009, 0.0189, ..., -0.0082, 0.0352, -0.0150],\n", - " ...,\n", - " [-0.0339, -0.0142, -0.0151, ..., 0.0229, 0.0032, 0.0038],\n", - " [ 0.0235, 0.0319, -0.0137, ..., -0.0121, 0.0112, 0.0162],\n", - " [ 0.0060, 0.0102, -0.0016, ..., 0.0118, 0.0158, -0.0140]])),\n", - " ('model.layers.17.mlp.gate_proj.weight',\n", - " tensor([[ 0.0285, -0.0090, -0.0095, ..., 0.0315, -0.0065, 0.0189],\n", - " [ 0.0040, -0.0358, -0.0039, ..., -0.0074, -0.0285, -0.0223],\n", - " [ 0.0202, 0.0021, -0.0104, ..., -0.0083, 0.0300, -0.0267],\n", - " ...,\n", - " [ 0.0093, -0.0008, -0.0372, ..., 0.0422, 0.0309, 0.0095],\n", - " [ 0.0027, 0.0252, 0.0378, ..., -0.0238, 0.0234, -0.0062],\n", - " [-0.0061, -0.0022, -0.0033, ..., 0.0157, -0.0296, 0.0034]])),\n", - " ('model.layers.17.mlp.up_proj.weight',\n", - " tensor([[ 0.0061, -0.0135, 0.0029, ..., 0.0328, 0.0008, -0.0072],\n", - " [ 0.0145, -0.0226, -0.0095, ..., 0.0114, 0.0224, -0.0160],\n", - " [ 0.0097, -0.0024, -0.0179, ..., 0.0073, -0.0061, -0.0195],\n", - " ...,\n", - " [ 0.0308, -0.0014, 0.0104, ..., 0.0047, 0.0026, 0.0243],\n", - " [-0.0364, 0.0350, 0.0031, ..., -0.0072, 0.0267, 0.0017],\n", - " [ 0.0227, -0.0146, 0.0146, ..., -0.0434, -0.0159, 0.0230]])),\n", - " ('model.layers.17.mlp.down_proj.weight',\n", - " tensor([[-0.0216, 0.0211, 0.0136, ..., -0.0004, 0.0051, 0.0415],\n", - " [-0.0061, -0.0123, 0.0156, ..., -0.0005, -0.0183, -0.0137],\n", - " [-0.0146, -0.0274, -0.0439, ..., -0.0033, -0.0030, -0.0074],\n", - " ...,\n", - " [-0.0108, -0.0005, -0.0094, ..., -0.0243, 0.0065, -0.0005],\n", - " [-0.0126, 0.0124, -0.0006, ..., -0.0282, -0.0110, 0.0128],\n", - " [-0.0162, -0.0102, 0.0025, ..., -0.0084, 0.0066, -0.0074]])),\n", - " ('model.layers.17.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.17.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.18.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.18.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.18.mixer.in_proj.weight',\n", - " tensor([[-9.4961e-03, -1.2349e-04, -7.1455e-03, ..., 1.9508e-02,\n", - " -6.8715e-03, -1.3565e-02],\n", - " [-2.9701e-03, 3.1580e-03, 1.8849e-02, ..., 7.6566e-03,\n", - " -1.0968e-02, -8.0445e-03],\n", - " [-1.5402e-02, -6.7267e-03, 9.6119e-03, ..., 1.9799e-02,\n", - " 2.0198e-03, -1.7366e-03],\n", - " ...,\n", - " [ 8.2379e-03, 5.1668e-03, 3.8116e-02, ..., -3.8710e-03,\n", - " 1.4452e-02, -2.5152e-02],\n", - " [ 1.1949e-02, -1.2245e-03, 1.0568e-02, ..., -3.1690e-02,\n", - " 3.8135e-05, 1.7263e-02],\n", - " [ 1.6173e-04, 5.6721e-04, 2.1043e-02, ..., -3.6167e-02,\n", - " -1.1129e-02, -9.6768e-03]])),\n", - " ('model.layers.18.mixer.conv1d.weight',\n", - " tensor([[[ 0.2776, 0.2169, -0.2840, 0.1736]],\n", - " \n", - " [[-0.0598, -0.2654, 0.2423, -0.0874]],\n", - " \n", - " [[-0.3612, -0.3049, -0.3197, -0.2763]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.1389, 0.2034, -0.1739, 0.1634]],\n", - " \n", - " [[-0.2836, -0.0471, 0.1284, -0.0099]],\n", - " \n", - " [[ 0.2952, -0.2676, -0.3961, 0.2656]]])),\n", - " ('model.layers.18.mixer.conv1d.bias',\n", - " tensor([ 0.1804, 0.0336, 0.4006, ..., 0.2943, -0.1079, 0.0963])),\n", - " ('model.layers.18.mixer.out_proj.weight',\n", - " tensor([[ 0.0109, -0.0181, 0.0148, ..., -0.0105, -0.0011, -0.0052],\n", - " [ 0.0507, 0.0100, -0.0273, ..., -0.0069, 0.0054, 0.0129],\n", - " [ 0.0014, 0.0423, -0.0193, ..., -0.0023, -0.0293, 0.0004],\n", - " ...,\n", - " [ 0.0420, -0.0401, 0.0205, ..., 0.0135, -0.0089, -0.0023],\n", - " [ 0.0242, 0.0273, 0.0139, ..., -0.0402, 0.0061, 0.0119],\n", - " [-0.0145, 0.0102, 0.0245, ..., 0.0205, -0.0251, 0.0006]])),\n", - " ('model.layers.18.mlp.gate_proj.weight',\n", - " tensor([[ 0.0241, -0.0086, 0.0136, ..., -0.0219, -0.0064, -0.0142],\n", - " [-0.0067, 0.0252, 0.0246, ..., -0.0205, -0.0273, 0.0137],\n", - " [-0.0030, 0.0055, -0.0063, ..., 0.0107, 0.0083, -0.0037],\n", - " ...,\n", - " [-0.0154, 0.0101, 0.0221, ..., 0.0025, -0.0109, 0.0133],\n", - " [-0.0175, 0.0105, -0.0246, ..., 0.0244, 0.0023, 0.0080],\n", - " [-0.0060, 0.0183, 0.0297, ..., 0.0420, -0.0006, -0.0119]])),\n", - " ('model.layers.18.mlp.up_proj.weight',\n", - " tensor([[ 0.0066, -0.0009, -0.0070, ..., -0.0064, 0.0002, 0.0196],\n", - " [-0.0173, -0.0362, -0.0011, ..., 0.0158, -0.0198, -0.0046],\n", - " [ 0.0133, -0.0090, -0.0092, ..., 0.0039, -0.0052, -0.0101],\n", - " ...,\n", - " [ 0.0077, -0.0063, 0.0010, ..., 0.0091, 0.0218, 0.0132],\n", - " [ 0.0005, -0.0046, 0.0207, ..., 0.0112, 0.0183, -0.0020],\n", - " [ 0.0238, -0.0022, 0.0364, ..., -0.0042, 0.0237, 0.0183]])),\n", - " ('model.layers.18.mlp.down_proj.weight',\n", - " tensor([[ 0.0305, 0.0178, -0.0264, ..., -0.0158, 0.0135, 0.0132],\n", - " [ 0.0248, -0.0061, 0.0144, ..., -0.0165, 0.0098, 0.0410],\n", - " [-0.0156, -0.0039, 0.0112, ..., -0.0431, -0.0084, -0.0197],\n", - " ...,\n", - " [ 0.0071, 0.0236, -0.0038, ..., 0.0035, -0.0236, 0.0106],\n", - " [-0.0369, -0.0029, -0.0182, ..., -0.0008, -0.0417, 0.0064],\n", - " [-0.0273, 0.0207, 0.0130, ..., 0.0372, 0.0163, 0.0273]])),\n", - " ('model.layers.18.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.18.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.19.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.19.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.19.mixer.in_proj.weight',\n", - " tensor([[-0.0079, 0.0147, -0.0337, ..., -0.0201, -0.0254, 0.0035],\n", - " [ 0.0139, 0.0054, -0.0093, ..., -0.0208, -0.0289, -0.0087],\n", - " [ 0.0004, -0.0034, 0.0090, ..., -0.0109, -0.0093, 0.0102],\n", - " ...,\n", - " [ 0.0128, 0.0015, -0.0101, ..., -0.0482, -0.0217, 0.0144],\n", - " [-0.0100, -0.0079, 0.0286, ..., -0.0025, -0.0210, 0.0164],\n", - " [-0.0264, 0.0015, 0.0031, ..., 0.0027, 0.0131, -0.0384]])),\n", - " ('model.layers.19.mixer.conv1d.weight',\n", - " tensor([[[ 0.4729, 0.3708, -0.4394, -0.3549]],\n", - " \n", - " [[ 0.2230, -0.3271, 0.3017, -0.2552]],\n", - " \n", - " [[-0.0417, 0.1893, 0.4552, -0.0644]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.2565, 0.0407, 0.3521, 0.4116]],\n", - " \n", - " [[ 0.0795, -0.0374, 0.1034, 0.4254]],\n", - " \n", - " [[ 0.3333, 0.2431, 0.3459, -0.2676]]])),\n", - " ('model.layers.19.mixer.conv1d.bias',\n", - " tensor([-0.2287, -0.4446, -0.2300, ..., -0.2317, -0.3395, 0.4310])),\n", - " ('model.layers.19.mixer.out_proj.weight',\n", - " tensor([[-0.0456, -0.0167, -0.0117, ..., -0.0068, -0.0150, 0.0125],\n", - " [ 0.0194, 0.0172, -0.0232, ..., -0.0202, -0.0066, 0.0083],\n", - " [ 0.0320, -0.0065, 0.0274, ..., 0.0200, 0.0090, 0.0105],\n", - " ...,\n", - " [ 0.0315, 0.0415, 0.0128, ..., -0.0143, -0.0338, -0.0231],\n", - " [ 0.0227, -0.0177, -0.0034, ..., 0.0174, 0.0006, 0.0212],\n", - " [ 0.0358, 0.0084, 0.0075, ..., 0.0091, 0.0062, 0.0114]])),\n", - " ('model.layers.19.mlp.gate_proj.weight',\n", - " tensor([[-0.0010, 0.0156, 0.0042, ..., -0.0181, 0.0113, 0.0089],\n", - " [-0.0182, 0.0068, -0.0043, ..., -0.0323, -0.0019, -0.0045],\n", - " [ 0.0168, -0.0093, -0.0162, ..., -0.0074, 0.0166, -0.0334],\n", - " ...,\n", - " [ 0.0038, -0.0211, -0.0054, ..., -0.0229, 0.0193, -0.0210],\n", - " [ 0.0153, -0.0372, 0.0119, ..., 0.0043, -0.0097, -0.0025],\n", - " [ 0.0037, 0.0208, -0.0135, ..., 0.0052, -0.0125, -0.0282]])),\n", - " ('model.layers.19.mlp.up_proj.weight',\n", - " tensor([[-0.0026, 0.0360, 0.0161, ..., 0.0199, -0.0283, -0.0026],\n", - " [ 0.0185, 0.0122, -0.0299, ..., 0.0125, 0.0063, 0.0387],\n", - " [-0.0085, -0.0010, -0.0054, ..., -0.0088, -0.0034, -0.0179],\n", - " ...,\n", - " [-0.0179, 0.0211, -0.0003, ..., -0.0071, -0.0145, 0.0235],\n", - " [-0.0002, 0.0060, -0.0172, ..., -0.0086, 0.0175, -0.0232],\n", - " [-0.0081, -0.0280, -0.0152, ..., -0.0221, 0.0047, -0.0077]])),\n", - " ('model.layers.19.mlp.down_proj.weight',\n", - " tensor([[ 0.0038, -0.0027, -0.0122, ..., 0.0090, 0.0044, 0.0128],\n", - " [ 0.0054, 0.0075, 0.0116, ..., 0.0232, 0.0130, 0.0298],\n", - " [-0.0498, -0.0208, -0.0127, ..., 0.0166, -0.0221, 0.0038],\n", - " ...,\n", - " [ 0.0101, 0.0051, 0.0209, ..., 0.0137, -0.0225, 0.0142],\n", - " [-0.0433, -0.0217, -0.0167, ..., -0.0179, -0.0191, -0.0021],\n", - " [-0.0020, 0.0084, -0.0114, ..., 0.0324, 0.0216, -0.0062]])),\n", - " ('model.layers.19.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.19.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.20.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.20.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.20.mixer.in_proj.weight',\n", - " tensor([[ 3.3776e-02, 3.6619e-02, 6.8532e-03, ..., 5.7664e-02,\n", - " -2.3083e-02, -6.2962e-02],\n", - " [-2.9787e-03, -2.5050e-03, -3.4841e-03, ..., 5.4946e-03,\n", - " 9.0683e-03, 2.1583e-04],\n", - " [ 7.4430e-03, -1.0495e-02, 3.5169e-02, ..., -5.1808e-02,\n", - " 3.2650e-03, -3.1967e-02],\n", - " ...,\n", - " [-5.8685e-02, 4.8452e-02, -1.2612e-02, ..., 1.2174e-02,\n", - " 1.0566e-02, -4.9561e-03],\n", - " [ 3.1722e-03, -2.9390e-03, 1.4502e-05, ..., -2.3297e-02,\n", - " -7.5403e-03, -1.3599e-02],\n", - " [ 1.4845e-02, -4.3150e-02, -1.0338e-02, ..., -1.1149e-02,\n", - " -3.3432e-02, 3.8337e-03]])),\n", - " ('model.layers.20.mixer.conv1d.weight',\n", - " tensor([[[-0.3842, 0.2397, 0.4873, -0.3091]],\n", - " \n", - " [[-0.1886, 0.0751, 0.2026, -0.2674]],\n", - " \n", - " [[-0.0594, 0.3119, -0.2404, 0.1652]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0028, 0.1315, 0.0515, 0.3189]],\n", - " \n", - " [[-0.1461, -0.0457, -0.0536, -0.2306]],\n", - " \n", - " [[-0.3025, -0.3339, 0.3007, -0.3007]]])),\n", - " ('model.layers.20.mixer.conv1d.bias',\n", - " tensor([-0.4901, -0.3784, -0.0173, ..., -0.3946, -0.0728, 0.2187])),\n", - " ('model.layers.20.mixer.out_proj.weight',\n", - " tensor([[ 0.0095, -0.0037, -0.0218, ..., 0.0080, 0.0062, 0.0246],\n", - " [-0.0197, 0.0037, 0.0076, ..., 0.0171, 0.0238, -0.0195],\n", - " [ 0.0364, -0.0165, 0.0224, ..., -0.0099, 0.0007, 0.0340],\n", - " ...,\n", - " [ 0.0235, -0.0072, -0.0319, ..., 0.0045, -0.0196, 0.0011],\n", - " [-0.0369, 0.0083, 0.0021, ..., -0.0357, -0.0039, -0.0150],\n", - " [-0.0174, -0.0211, 0.0111, ..., 0.0251, 0.0040, -0.0308]])),\n", - " ('model.layers.20.mlp.gate_proj.weight',\n", - " tensor([[ 0.0161, -0.0019, -0.0473, ..., 0.0019, 0.0075, -0.0038],\n", - " [-0.0321, -0.0020, -0.0100, ..., 0.0035, 0.0291, -0.0058],\n", - " [-0.0158, 0.0020, 0.0353, ..., 0.0125, 0.0228, -0.0392],\n", - " ...,\n", - " [ 0.0113, 0.0171, 0.0235, ..., 0.0043, 0.0378, 0.0391],\n", - " [ 0.0090, 0.0067, 0.0031, ..., 0.0291, -0.0052, -0.0216],\n", - " [ 0.0042, -0.0112, -0.0161, ..., -0.0063, -0.0156, 0.0211]])),\n", - " ('model.layers.20.mlp.up_proj.weight',\n", - " tensor([[ 0.0104, -0.0302, -0.0220, ..., -0.0072, -0.0083, -0.0066],\n", - " [ 0.0409, -0.0116, -0.0125, ..., 0.0182, 0.0267, 0.0099],\n", - " [-0.0055, 0.0104, 0.0027, ..., -0.0075, -0.0368, -0.0092],\n", - " ...,\n", - " [-0.0089, 0.0243, -0.0028, ..., -0.0136, -0.0176, -0.0054],\n", - " [ 0.0088, 0.0365, -0.0354, ..., 0.0035, 0.0280, 0.0155],\n", - " [-0.0472, 0.0088, 0.0102, ..., -0.0120, 0.0004, -0.0011]])),\n", - " ('model.layers.20.mlp.down_proj.weight',\n", - " tensor([[-0.0089, -0.0112, -0.0007, ..., 0.0360, -0.0077, 0.0261],\n", - " [ 0.0080, -0.0128, -0.0445, ..., 0.0095, -0.0298, 0.0176],\n", - " [ 0.0357, -0.0262, 0.0028, ..., 0.0162, 0.0089, 0.0050],\n", - " ...,\n", - " [-0.0129, 0.0216, 0.0125, ..., -0.0062, -0.0344, -0.0218],\n", - " [ 0.0006, -0.0143, -0.0099, ..., -0.0359, 0.0268, 0.0259],\n", - " [ 0.0222, -0.0154, 0.0013, ..., 0.0108, -0.0077, 0.0186]])),\n", - " ('model.layers.20.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.20.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.21.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.21.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.21.mixer.in_proj.weight',\n", - " tensor([[-0.0300, 0.0058, -0.0107, ..., -0.0318, 0.0350, 0.0350],\n", - " [ 0.0186, 0.0238, -0.0268, ..., 0.0142, -0.0277, -0.0095],\n", - " [-0.0061, 0.0083, 0.0072, ..., 0.0161, 0.0027, -0.0051],\n", - " ...,\n", - " [-0.0358, 0.0330, 0.0151, ..., -0.0376, 0.0057, 0.0174],\n", - " [-0.0021, 0.0068, 0.0151, ..., 0.0077, -0.0353, 0.0095],\n", - " [-0.0113, -0.0043, 0.0064, ..., -0.0063, -0.0232, -0.0058]])),\n", - " ('model.layers.21.mixer.conv1d.weight',\n", - " tensor([[[ 0.0354, 0.0496, -0.0106, 0.0084]],\n", - " \n", - " [[ 0.2553, 0.3217, -0.0078, -0.2333]],\n", - " \n", - " [[-0.1390, 0.0323, 0.4914, -0.2047]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.2243, 0.2984, 0.0188, 0.1830]],\n", - " \n", - " [[ 0.0756, 0.1443, -0.4898, -0.2082]],\n", - " \n", - " [[-0.3685, -0.1311, -0.4037, -0.3276]]])),\n", - " ('model.layers.21.mixer.conv1d.bias',\n", - " tensor([-0.2444, -0.1852, 0.2215, ..., 0.4515, 0.2532, -0.2388])),\n", - " ('model.layers.21.mixer.out_proj.weight',\n", - " tensor([[ 0.0232, 0.0328, 0.0026, ..., -0.0575, 0.0157, -0.0072],\n", - " [-0.0226, 0.0058, -0.0346, ..., 0.0092, 0.0078, 0.0108],\n", - " [ 0.0045, 0.0247, 0.0150, ..., -0.0085, 0.0268, 0.0253],\n", - " ...,\n", - " [ 0.0268, 0.0092, 0.0141, ..., 0.0062, 0.0177, -0.0405],\n", - " [ 0.0163, -0.0269, -0.0177, ..., 0.0029, -0.0080, -0.0036],\n", - " [ 0.0064, 0.0126, 0.0126, ..., -0.0400, -0.0015, -0.0088]])),\n", - " ('model.layers.21.mlp.gate_proj.weight',\n", - " tensor([[-3.7050e-02, 4.5834e-02, 1.9280e-02, ..., 1.6761e-02,\n", - " -5.8295e-03, -1.4284e-02],\n", - " [ 3.0156e-02, 3.2832e-02, 1.1083e-02, ..., -5.8261e-03,\n", - " -3.9076e-02, 5.3379e-03],\n", - " [ 1.3118e-03, 3.1510e-02, 1.5472e-02, ..., 1.8213e-02,\n", - " -2.5180e-02, 6.1512e-04],\n", - " ...,\n", - " [ 4.2010e-02, 1.0362e-02, 7.1759e-03, ..., 1.8667e-03,\n", - " -7.2165e-03, 1.6297e-02],\n", - " [ 1.8175e-02, 1.2840e-02, 3.2857e-03, ..., 1.8495e-02,\n", - " -7.7709e-03, 4.3964e-04],\n", - " [-9.2628e-05, 2.1701e-02, 2.1256e-02, ..., 2.5241e-02,\n", - " 5.0683e-02, -2.5481e-02]])),\n", - " ('model.layers.21.mlp.up_proj.weight',\n", - " tensor([[ 0.0228, 0.0082, -0.0083, ..., 0.0288, 0.0211, 0.0085],\n", - " [-0.0155, 0.0179, 0.0111, ..., -0.0218, -0.0162, -0.0052],\n", - " [ 0.0016, 0.0009, 0.0230, ..., -0.0017, 0.0131, 0.0255],\n", - " ...,\n", - " [-0.0098, -0.0098, -0.0188, ..., 0.0063, 0.0082, 0.0052],\n", - " [-0.0028, 0.0249, -0.0153, ..., -0.0208, 0.0130, -0.0093],\n", - " [ 0.0105, -0.0072, -0.0379, ..., 0.0035, 0.0182, 0.0307]])),\n", - " ('model.layers.21.mlp.down_proj.weight',\n", - " tensor([[-0.0445, -0.0116, 0.0058, ..., 0.0081, -0.0099, 0.0094],\n", - " [ 0.0106, -0.0387, 0.0051, ..., 0.0017, 0.0075, 0.0136],\n", - " [ 0.0022, 0.0058, -0.0268, ..., -0.0088, -0.0149, 0.0125],\n", - " ...,\n", - " [-0.0015, -0.0156, -0.0225, ..., 0.0100, -0.0118, -0.0019],\n", - " [-0.0161, -0.0225, -0.0060, ..., 0.0073, -0.0072, 0.0205],\n", - " [-0.0112, 0.0046, -0.0089, ..., -0.0014, -0.0221, 0.0124]])),\n", - " ('model.layers.21.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.21.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.22.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.22.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.22.mixer.in_proj.weight',\n", - " tensor([[-1.1591e-02, -6.0118e-03, -2.2227e-03, ..., -7.1433e-03,\n", - " -1.5757e-02, -1.5315e-03],\n", - " [-7.6057e-03, -4.2199e-02, 1.4478e-02, ..., 5.6496e-02,\n", - " 8.9105e-05, -3.8658e-03],\n", - " [-1.0330e-03, 2.3586e-02, 2.1835e-02, ..., -1.4911e-03,\n", - " -1.6604e-02, -4.5245e-03],\n", - " ...,\n", - " [-6.7261e-03, -6.9826e-03, -9.3003e-03, ..., -4.3939e-02,\n", - " 2.3792e-02, -5.5165e-03],\n", - " [-1.1798e-02, -3.4709e-02, -4.1277e-03, ..., -5.1867e-03,\n", - " 5.2496e-03, -6.0055e-03],\n", - " [ 7.3402e-04, -1.9525e-02, -5.8966e-03, ..., -1.5972e-02,\n", - " -1.5446e-02, -2.7164e-02]])),\n", - " ('model.layers.22.mixer.conv1d.weight',\n", - " tensor([[[-0.3791, 0.0616, 0.0369, 0.1365]],\n", - " \n", - " [[-0.4674, -0.4557, 0.3894, -0.4765]],\n", - " \n", - " [[ 0.3333, 0.2265, 0.1385, -0.1352]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4363, -0.3526, -0.3982, -0.1049]],\n", - " \n", - " [[ 0.4798, -0.3912, 0.4059, -0.1379]],\n", - " \n", - " [[-0.4427, 0.4661, -0.1990, 0.1668]]])),\n", - " ('model.layers.22.mixer.conv1d.bias',\n", - " tensor([-0.1823, -0.4117, 0.4443, ..., -0.0024, 0.2144, -0.4922])),\n", - " ('model.layers.22.mixer.out_proj.weight',\n", - " tensor([[ 0.0138, -0.0169, -0.0349, ..., -0.0045, 0.0023, -0.0389],\n", - " [ 0.0250, 0.0040, -0.0259, ..., 0.0458, 0.0311, -0.0054],\n", - " [-0.0056, 0.0012, -0.0027, ..., 0.0095, -0.0089, -0.0106],\n", - " ...,\n", - " [ 0.0228, -0.0258, 0.0040, ..., 0.0276, -0.0121, -0.0239],\n", - " [ 0.0082, 0.0041, 0.0145, ..., 0.0079, -0.0076, 0.0177],\n", - " [ 0.0310, -0.0092, -0.0174, ..., 0.0179, 0.0231, -0.0035]])),\n", - " ('model.layers.22.mlp.gate_proj.weight',\n", - " tensor([[ 0.0090, -0.0178, -0.0120, ..., -0.0073, -0.0149, 0.0187],\n", - " [ 0.0263, -0.0093, -0.0074, ..., -0.0472, 0.0049, 0.0288],\n", - " [ 0.0159, -0.0083, 0.0291, ..., 0.0089, -0.0076, -0.0167],\n", - " ...,\n", - " [-0.0008, 0.0206, 0.0199, ..., -0.0134, -0.0366, -0.0202],\n", - " [-0.0069, -0.0275, 0.0054, ..., 0.0093, 0.0108, 0.0094],\n", - " [ 0.0198, 0.0033, -0.0118, ..., -0.0262, 0.0241, 0.0084]])),\n", - " ('model.layers.22.mlp.up_proj.weight',\n", - " tensor([[-0.0277, 0.0038, 0.0006, ..., -0.0222, -0.0313, -0.0133],\n", - " [ 0.0132, -0.0373, 0.0109, ..., 0.0359, -0.0116, 0.0099],\n", - " [ 0.0139, -0.0185, 0.0247, ..., 0.0178, 0.0192, 0.0049],\n", - " ...,\n", - " [ 0.0362, 0.0072, -0.0236, ..., -0.0238, 0.0319, -0.0210],\n", - " [ 0.0013, -0.0047, -0.0060, ..., 0.0106, -0.0074, -0.0185],\n", - " [-0.0228, 0.0176, -0.0047, ..., -0.0034, -0.0174, -0.0264]])),\n", - " ('model.layers.22.mlp.down_proj.weight',\n", - " tensor([[ 0.0149, 0.0122, -0.0037, ..., 0.0044, 0.0171, -0.0186],\n", - " [-0.0037, -0.0002, 0.0066, ..., 0.0263, -0.0025, -0.0012],\n", - " [-0.0075, 0.0209, 0.0045, ..., 0.0082, -0.0160, 0.0079],\n", - " ...,\n", - " [ 0.0001, 0.0507, -0.0078, ..., 0.0001, -0.0119, 0.0286],\n", - " [-0.0198, -0.0122, 0.0047, ..., -0.0052, 0.0130, -0.0007],\n", - " [ 0.0241, -0.0002, -0.0147, ..., 0.0219, -0.0020, -0.0071]])),\n", - " ('model.layers.22.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.22.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.23.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.23.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.23.mixer.in_proj.weight',\n", - " tensor([[-0.0017, 0.0027, -0.0150, ..., 0.0392, -0.0079, -0.0367],\n", - " [ 0.0183, 0.0261, -0.0262, ..., -0.0157, 0.0197, 0.0135],\n", - " [-0.0030, 0.0170, 0.0032, ..., 0.0059, 0.0299, 0.0158],\n", - " ...,\n", - " [-0.0149, 0.0218, 0.0072, ..., -0.0302, 0.0035, 0.0153],\n", - " [-0.0135, 0.0425, 0.0331, ..., -0.0119, -0.0364, 0.0365],\n", - " [-0.0215, -0.0242, 0.0271, ..., 0.0500, 0.0293, 0.0100]])),\n", - " ('model.layers.23.mixer.conv1d.weight',\n", - " tensor([[[ 0.2464, 0.3726, 0.2719, 0.3580]],\n", - " \n", - " [[-0.0520, 0.0010, 0.1396, -0.4634]],\n", - " \n", - " [[ 0.1383, 0.4039, -0.3622, 0.1499]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.4094, 0.0541, 0.2240, -0.1545]],\n", - " \n", - " [[-0.4393, 0.1323, 0.1705, -0.1722]],\n", - " \n", - " [[ 0.2166, -0.4335, -0.4088, -0.1159]]])),\n", - " ('model.layers.23.mixer.conv1d.bias',\n", - " tensor([ 0.3175, -0.0325, -0.4654, ..., 0.3869, -0.2534, 0.1588])),\n", - " ('model.layers.23.mixer.out_proj.weight',\n", - " tensor([[-0.0354, -0.0041, 0.0196, ..., -0.0218, -0.0222, 0.0126],\n", - " [-0.0155, -0.0067, -0.0007, ..., 0.0112, -0.0036, -0.0054],\n", - " [ 0.0141, 0.0040, -0.0218, ..., -0.0178, -0.0031, 0.0162],\n", - " ...,\n", - " [ 0.0264, 0.0063, 0.0088, ..., -0.0310, -0.0116, 0.0239],\n", - " [-0.0031, 0.0056, -0.0243, ..., -0.0350, 0.0004, 0.0004],\n", - " [ 0.0229, -0.0201, 0.0124, ..., 0.0313, -0.0412, -0.0033]])),\n", - " ('model.layers.23.mlp.gate_proj.weight',\n", - " tensor([[ 0.0026, -0.0155, 0.0595, ..., 0.0204, 0.0172, 0.0378],\n", - " [-0.0011, -0.0253, 0.0039, ..., 0.0330, -0.0487, -0.0195],\n", - " [ 0.0174, 0.0039, -0.0029, ..., -0.0026, 0.0104, 0.0108],\n", - " ...,\n", - " [-0.0159, 0.0008, 0.0173, ..., -0.0020, 0.0085, -0.0043],\n", - " [ 0.0101, 0.0221, -0.0034, ..., -0.0268, 0.0056, 0.0137],\n", - " [-0.0031, -0.0151, 0.0073, ..., -0.0083, -0.0064, 0.0109]])),\n", - " ('model.layers.23.mlp.up_proj.weight',\n", - " tensor([[ 0.0173, -0.0132, -0.0027, ..., 0.0391, 0.0268, -0.0185],\n", - " [ 0.0221, -0.0110, -0.0108, ..., -0.0302, 0.0170, 0.0139],\n", - " [-0.0047, -0.0373, 0.0056, ..., -0.0389, -0.0175, -0.0410],\n", - " ...,\n", - " [ 0.0003, 0.0153, 0.0160, ..., 0.0002, -0.0136, 0.0417],\n", - " [-0.0059, -0.0150, -0.0111, ..., 0.0163, 0.0171, 0.0267],\n", - " [-0.0123, -0.0032, 0.0193, ..., -0.0051, -0.0051, -0.0089]])),\n", - " ('model.layers.23.mlp.down_proj.weight',\n", - " tensor([[-0.0092, -0.0148, -0.0345, ..., -0.0240, 0.0425, -0.0099],\n", - " [ 0.0458, 0.0156, -0.0067, ..., -0.0283, 0.0401, 0.0074],\n", - " [ 0.0180, -0.0008, 0.0049, ..., -0.0085, -0.0157, 0.0044],\n", - " ...,\n", - " [-0.0207, 0.0074, -0.0176, ..., 0.0038, -0.0238, -0.0026],\n", - " [-0.0201, 0.0078, 0.0243, ..., -0.0031, 0.0080, -0.0176],\n", - " [-0.0034, 0.0191, 0.0391, ..., -0.0114, 0.0133, -0.0261]])),\n", - " ('model.layers.23.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.23.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.24.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.24.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.24.mixer.in_proj.weight',\n", - " tensor([[-0.0184, -0.0299, 0.0165, ..., 0.0035, 0.0417, -0.0170],\n", - " [-0.0346, -0.0226, 0.0064, ..., 0.0072, 0.0457, -0.0148],\n", - " [ 0.0032, -0.0245, -0.0474, ..., -0.0054, -0.0044, 0.0278],\n", - " ...,\n", - " [ 0.0139, 0.0133, -0.0185, ..., 0.0188, 0.0119, -0.0205],\n", - " [ 0.0235, 0.0161, -0.0095, ..., 0.0013, -0.0382, 0.0213],\n", - " [ 0.0031, -0.0394, 0.0275, ..., -0.0068, 0.0024, 0.0179]])),\n", - " ('model.layers.24.mixer.conv1d.weight',\n", - " tensor([[[-0.1857, -0.4692, 0.4791, 0.3706]],\n", - " \n", - " [[ 0.1749, 0.4182, -0.2338, 0.0838]],\n", - " \n", - " [[-0.1204, -0.2985, -0.0470, 0.4674]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.1485, 0.3118, -0.4916, -0.1610]],\n", - " \n", - " [[ 0.0684, -0.2980, 0.4517, -0.3662]],\n", - " \n", - " [[ 0.2353, -0.2156, -0.3332, -0.0665]]])),\n", - " ('model.layers.24.mixer.conv1d.bias',\n", - " tensor([-0.4464, -0.3485, -0.3916, ..., 0.2513, -0.0601, 0.1546])),\n", - " ('model.layers.24.mixer.out_proj.weight',\n", - " tensor([[-0.0023, 0.0087, -0.0280, ..., 0.0338, -0.0095, -0.0237],\n", - " [-0.0086, -0.0084, 0.0180, ..., 0.0350, 0.0463, -0.0270],\n", - " [-0.0093, -0.0009, 0.0236, ..., 0.0158, 0.0246, 0.0068],\n", - " ...,\n", - " [ 0.0526, 0.0009, 0.0039, ..., -0.0206, -0.0538, 0.0287],\n", - " [ 0.0054, -0.0053, -0.0108, ..., 0.0167, -0.0997, 0.0036],\n", - " [ 0.0009, -0.0297, -0.0424, ..., -0.0096, -0.0235, 0.0117]])),\n", - " ('model.layers.24.mlp.gate_proj.weight',\n", - " tensor([[-0.0265, 0.0259, 0.0224, ..., -0.0080, -0.0394, 0.0290],\n", - " [-0.0101, -0.0256, 0.0079, ..., -0.0017, -0.0287, -0.0163],\n", - " [ 0.0079, -0.0021, -0.0299, ..., 0.0076, 0.0063, 0.0082],\n", - " ...,\n", - " [ 0.0061, 0.0121, 0.0275, ..., -0.0162, 0.0025, -0.0075],\n", - " [-0.0039, -0.0217, -0.0428, ..., -0.0253, 0.0231, 0.0095],\n", - " [-0.0187, 0.0077, -0.0442, ..., 0.0358, -0.0084, -0.0132]])),\n", - " ('model.layers.24.mlp.up_proj.weight',\n", - " tensor([[-0.0201, -0.0119, 0.0505, ..., -0.0025, -0.0187, 0.0011],\n", - " [-0.0105, 0.0154, -0.0163, ..., 0.0248, 0.0028, 0.0178],\n", - " [-0.0163, -0.0271, -0.0100, ..., 0.0129, -0.0220, 0.0269],\n", - " ...,\n", - " [ 0.0138, 0.0329, -0.0091, ..., 0.0038, -0.0194, -0.0223],\n", - " [ 0.0469, 0.0291, -0.0027, ..., 0.0231, 0.0261, 0.0151],\n", - " [-0.0093, -0.0098, 0.0013, ..., 0.0078, -0.0145, 0.0268]])),\n", - " ('model.layers.24.mlp.down_proj.weight',\n", - " tensor([[-0.0195, -0.0003, -0.0046, ..., -0.0132, -0.0118, 0.0242],\n", - " [-0.0267, 0.0199, 0.0243, ..., -0.0063, 0.0134, -0.0163],\n", - " [-0.0044, -0.0303, -0.0215, ..., -0.0148, -0.0216, 0.0079],\n", - " ...,\n", - " [ 0.0159, 0.0180, 0.0098, ..., -0.0126, 0.0176, 0.0087],\n", - " [-0.0203, 0.0041, -0.0256, ..., -0.0047, -0.0236, -0.0256],\n", - " [-0.0017, 0.0133, 0.0490, ..., -0.0344, -0.0118, 0.0020]])),\n", - " ('model.layers.24.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.24.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.25.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.25.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.25.mixer.in_proj.weight',\n", - " tensor([[ 0.0064, 0.0039, 0.0014, ..., 0.0130, -0.0169, 0.0010],\n", - " [ 0.0371, 0.0241, 0.0203, ..., 0.0078, 0.0463, 0.0034],\n", - " [ 0.0184, -0.0431, -0.0026, ..., -0.0164, 0.0279, -0.0138],\n", - " ...,\n", - " [ 0.0146, -0.0138, -0.0418, ..., 0.0234, 0.0145, -0.0213],\n", - " [ 0.0124, -0.0298, -0.0164, ..., -0.0169, 0.0026, -0.0180],\n", - " [-0.0250, -0.0008, -0.0133, ..., -0.0131, -0.0064, 0.0071]])),\n", - " ('model.layers.25.mixer.conv1d.weight',\n", - " tensor([[[ 0.0171, -0.3423, -0.1701, 0.4869]],\n", - " \n", - " [[-0.4648, 0.4797, 0.3531, -0.3819]],\n", - " \n", - " [[-0.1660, -0.3489, -0.2488, 0.4428]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3545, -0.1567, -0.2646, 0.3590]],\n", - " \n", - " [[-0.2175, 0.4394, 0.3840, 0.2620]],\n", - " \n", - " [[ 0.1335, -0.3655, 0.3256, -0.1752]]])),\n", - " ('model.layers.25.mixer.conv1d.bias',\n", - " tensor([-0.0935, 0.0170, 0.0779, ..., -0.2362, 0.2879, 0.2390])),\n", - " ('model.layers.25.mixer.out_proj.weight',\n", - " tensor([[ 2.0220e-02, 5.0645e-05, -1.7425e-02, ..., 8.6082e-03,\n", - " -1.8566e-02, 1.3872e-02],\n", - " [ 2.9139e-02, 1.1096e-02, 4.4168e-02, ..., 3.5600e-02,\n", - " 7.3446e-03, -1.6368e-02],\n", - " [-3.2418e-02, 6.9682e-03, 3.1648e-02, ..., 1.4050e-02,\n", - " -1.6554e-02, 7.2751e-03],\n", - " ...,\n", - " [-3.3057e-02, -7.0545e-04, 3.9661e-02, ..., 2.0690e-02,\n", - " -1.0262e-02, -4.9292e-03],\n", - " [ 1.9849e-02, 1.9666e-02, -1.9398e-02, ..., 1.9285e-02,\n", - " 2.2522e-02, -6.0243e-03],\n", - " [ 1.7683e-02, 2.4301e-02, 7.2223e-03, ..., 3.1373e-02,\n", - " -5.7889e-03, 1.1855e-02]])),\n", - " ('model.layers.25.mlp.gate_proj.weight',\n", - " tensor([[-1.6223e-02, 4.5519e-03, -1.9218e-02, ..., 6.3580e-03,\n", - " -1.2723e-02, -9.7756e-03],\n", - " [-7.4200e-03, 1.8729e-02, 2.6924e-03, ..., 8.2305e-03,\n", - " -1.5727e-02, -9.8748e-03],\n", - " [ 3.2143e-02, -6.1559e-02, 1.6362e-02, ..., -3.6189e-04,\n", - " 1.2017e-04, -1.5734e-02],\n", - " ...,\n", - " [-1.4649e-02, -4.7663e-03, -1.9292e-02, ..., -1.9359e-02,\n", - " 1.8795e-02, 1.0221e-02],\n", - " [-2.4459e-02, 1.1684e-02, -2.8023e-02, ..., 8.0104e-03,\n", - " 8.5950e-05, 1.0542e-02],\n", - " [-4.5679e-03, -1.1421e-02, -2.1099e-02, ..., 4.5089e-03,\n", - " -3.0686e-02, -9.6116e-03]])),\n", - " ('model.layers.25.mlp.up_proj.weight',\n", - " tensor([[-0.0204, -0.0013, -0.0264, ..., -0.0081, -0.0027, 0.0215],\n", - " [-0.0161, 0.0051, -0.0111, ..., -0.0244, 0.0043, -0.0043],\n", - " [-0.0511, 0.0006, -0.0249, ..., 0.0069, 0.0615, 0.0123],\n", - " ...,\n", - " [-0.0086, -0.0016, 0.0064, ..., -0.0347, 0.0097, -0.0134],\n", - " [-0.0003, 0.0015, -0.0053, ..., 0.0210, 0.0135, 0.0337],\n", - " [-0.0205, 0.0028, -0.0272, ..., -0.0168, -0.0072, 0.0019]])),\n", - " ('model.layers.25.mlp.down_proj.weight',\n", - " tensor([[ 0.0166, 0.0044, 0.0180, ..., -0.0127, 0.0070, -0.0066],\n", - " [-0.0056, 0.0140, 0.0151, ..., -0.0239, -0.0140, 0.0470],\n", - " [-0.0030, -0.0093, -0.0188, ..., -0.0090, -0.0092, -0.0088],\n", - " ...,\n", - " [ 0.0465, 0.0277, -0.0349, ..., 0.0424, 0.0015, 0.0206],\n", - " [-0.0096, 0.0174, 0.0250, ..., -0.0142, -0.0022, -0.0141],\n", - " [-0.0195, -0.0174, 0.0033, ..., 0.0027, -0.0061, -0.0108]])),\n", - " ('model.layers.25.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.25.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.26.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.26.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.26.mixer.in_proj.weight',\n", - " tensor([[ 0.0112, 0.0060, -0.0038, ..., -0.0164, 0.0111, 0.0105],\n", - " [ 0.0227, -0.0248, 0.0240, ..., 0.0103, -0.0373, -0.0051],\n", - " [-0.0073, 0.0227, -0.0190, ..., 0.0048, -0.0101, -0.0137],\n", - " ...,\n", - " [ 0.0086, -0.0084, 0.0177, ..., -0.0245, 0.0119, 0.0022],\n", - " [-0.0080, -0.0284, 0.0440, ..., 0.0340, -0.0093, 0.0130],\n", - " [-0.0107, 0.0234, -0.0279, ..., 0.0106, -0.0169, -0.0001]])),\n", - " ('model.layers.26.mixer.conv1d.weight',\n", - " tensor([[[ 0.0550, -0.3464, -0.2378, -0.1244]],\n", - " \n", - " [[-0.0925, -0.2497, 0.2629, -0.1821]],\n", - " \n", - " [[-0.4524, 0.3462, -0.4604, -0.2758]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4555, -0.0839, 0.3936, -0.3707]],\n", - " \n", - " [[ 0.3409, -0.4109, 0.0890, -0.3629]],\n", - " \n", - " [[-0.2769, 0.4033, -0.1090, 0.3055]]])),\n", - " ('model.layers.26.mixer.conv1d.bias',\n", - " tensor([-0.2286, -0.2395, -0.2517, ..., 0.0537, 0.0906, 0.4936])),\n", - " ('model.layers.26.mixer.out_proj.weight',\n", - " tensor([[-0.0316, -0.0423, -0.0053, ..., 0.0024, 0.0084, -0.0270],\n", - " [ 0.0458, -0.0243, 0.0060, ..., -0.0007, -0.0161, -0.0232],\n", - " [ 0.0388, -0.0126, 0.0184, ..., -0.0059, 0.0061, 0.0090],\n", - " ...,\n", - " [ 0.0487, 0.0305, -0.0175, ..., -0.0250, -0.0158, -0.0035],\n", - " [-0.0148, -0.0224, 0.0095, ..., -0.0102, -0.0226, 0.0272],\n", - " [-0.0061, 0.0067, 0.0069, ..., 0.0038, -0.0277, -0.0168]])),\n", - " ('model.layers.26.mlp.gate_proj.weight',\n", - " tensor([[-1.9812e-02, 8.3232e-03, 3.0347e-03, ..., 2.1982e-02,\n", - " 1.3550e-02, -1.1203e-02],\n", - " [ 2.2460e-02, 4.9811e-03, -2.2167e-02, ..., 1.3932e-03,\n", - " 5.3891e-03, -2.8310e-02],\n", - " [ 1.1011e-02, -1.2903e-02, -2.8861e-02, ..., 2.6808e-02,\n", - " -2.8479e-03, -1.3105e-02],\n", - " ...,\n", - " [ 1.1078e-03, -1.1789e-02, -4.4165e-02, ..., 8.2950e-03,\n", - " -1.8015e-02, -1.2234e-02],\n", - " [-2.0721e-02, -4.7919e-04, -4.9474e-02, ..., 7.9999e-05,\n", - " 1.7886e-02, -4.4699e-02],\n", - " [ 8.1279e-03, 1.2636e-02, -2.0932e-02, ..., -3.0361e-03,\n", - " 3.3468e-03, 2.7677e-02]])),\n", - " ('model.layers.26.mlp.up_proj.weight',\n", - " tensor([[-0.0301, -0.0025, -0.0147, ..., -0.0186, 0.0058, -0.0057],\n", - " [ 0.0303, -0.0341, 0.0142, ..., -0.0252, -0.0247, 0.0280],\n", - " [ 0.0209, -0.0425, 0.0073, ..., 0.0063, -0.0040, -0.0076],\n", - " ...,\n", - " [-0.0172, -0.0199, 0.0125, ..., 0.0363, 0.0118, -0.0124],\n", - " [-0.0108, 0.0042, -0.0475, ..., 0.0091, -0.0185, 0.0144],\n", - " [-0.0275, -0.0049, 0.0183, ..., -0.0001, -0.0119, -0.0359]])),\n", - " ('model.layers.26.mlp.down_proj.weight',\n", - " tensor([[-0.0197, -0.0082, -0.0224, ..., -0.0469, -0.0076, -0.0375],\n", - " [-0.0070, -0.0071, 0.0190, ..., -0.0125, 0.0068, 0.0166],\n", - " [ 0.0062, -0.0072, 0.0189, ..., -0.0244, -0.0292, -0.0328],\n", - " ...,\n", - " [-0.0054, 0.0219, 0.0058, ..., 0.0118, 0.0136, -0.0221],\n", - " [-0.0133, 0.0299, -0.0182, ..., -0.0496, -0.0202, 0.0196],\n", - " [-0.0131, -0.0237, -0.0473, ..., 0.0066, 0.0119, 0.0100]])),\n", - " ('model.layers.26.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.26.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.27.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.27.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.27.mixer.in_proj.weight',\n", - " tensor([[ 0.0200, -0.0276, -0.0274, ..., 0.0282, 0.0025, 0.0215],\n", - " [ 0.0054, 0.0218, -0.0175, ..., -0.0054, 0.0211, -0.0073],\n", - " [ 0.0100, -0.0023, 0.0162, ..., 0.0008, -0.0193, -0.0050],\n", - " ...,\n", - " [-0.0241, -0.0197, -0.0142, ..., 0.0039, -0.0175, 0.0045],\n", - " [ 0.0214, 0.0137, -0.0155, ..., -0.0212, 0.0089, 0.0165],\n", - " [ 0.0086, 0.0181, 0.0069, ..., -0.0093, -0.0272, 0.0068]])),\n", - " ('model.layers.27.mixer.conv1d.weight',\n", - " tensor([[[ 0.0519, 0.2061, 0.2635, 0.4916]],\n", - " \n", - " [[ 0.3745, -0.0860, -0.2310, -0.4250]],\n", - " \n", - " [[ 0.0565, 0.3699, 0.2812, -0.4201]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.4073, 0.1852, -0.1687, -0.2643]],\n", - " \n", - " [[-0.0865, -0.0894, 0.2650, -0.4522]],\n", - " \n", - " [[-0.0987, 0.0925, -0.2098, 0.0325]]])),\n", - " ('model.layers.27.mixer.conv1d.bias',\n", - " tensor([-0.4788, -0.0231, -0.4210, ..., -0.3143, -0.2893, 0.0570])),\n", - " ('model.layers.27.mixer.out_proj.weight',\n", - " tensor([[-0.0294, -0.0038, -0.0213, ..., -0.0141, 0.0072, -0.0359],\n", - " [ 0.0131, 0.0173, 0.0159, ..., 0.0030, 0.0400, -0.0065],\n", - " [-0.0111, 0.0374, 0.0109, ..., -0.0338, 0.0312, 0.0073],\n", - " ...,\n", - " [-0.0004, 0.0282, 0.0148, ..., 0.0165, 0.0062, -0.0177],\n", - " [ 0.0265, -0.0331, -0.0056, ..., 0.0407, 0.0154, 0.0176],\n", - " [ 0.0209, -0.0293, 0.0009, ..., -0.0240, -0.0029, -0.0407]])),\n", - " ('model.layers.27.mlp.gate_proj.weight',\n", - " tensor([[-0.0118, 0.0202, -0.0012, ..., 0.0101, 0.0075, 0.0102],\n", - " [ 0.0102, -0.0062, 0.0330, ..., -0.0024, -0.0245, -0.0237],\n", - " [-0.0008, 0.0202, -0.0097, ..., 0.0022, -0.0152, -0.0128],\n", - " ...,\n", - " [-0.0461, 0.0178, 0.0253, ..., 0.0319, 0.0173, -0.0099],\n", - " [ 0.0014, -0.0256, 0.0224, ..., 0.0272, 0.0045, 0.0192],\n", - " [ 0.0146, -0.0357, -0.0089, ..., -0.0147, 0.0383, 0.0354]])),\n", - " ('model.layers.27.mlp.up_proj.weight',\n", - " tensor([[-3.1854e-02, -1.0290e-03, -3.4564e-03, ..., 3.3551e-03,\n", - " 3.2845e-02, 2.1107e-02],\n", - " [-4.8083e-04, -5.8388e-03, 1.7324e-03, ..., 2.0575e-02,\n", - " -1.1685e-02, 1.2504e-02],\n", - " [ 4.6267e-02, -1.8935e-02, -2.4184e-02, ..., -4.8211e-02,\n", - " -3.3912e-04, 3.0527e-02],\n", - " ...,\n", - " [-6.9427e-03, -4.8680e-03, 3.2021e-02, ..., 1.4236e-02,\n", - " 1.9532e-02, 1.3339e-02],\n", - " [ 1.2463e-02, -5.5923e-03, -1.5680e-02, ..., 8.7956e-03,\n", - " 2.8262e-02, -1.2526e-02],\n", - " [-4.8530e-03, -8.8749e-05, 3.3507e-02, ..., -2.8260e-02,\n", - " -2.0571e-03, -8.3943e-03]])),\n", - " ('model.layers.27.mlp.down_proj.weight',\n", - " tensor([[-0.0457, -0.0267, -0.0210, ..., -0.0093, -0.0016, -0.0008],\n", - " [-0.0053, 0.0284, -0.0003, ..., 0.0065, -0.0117, 0.0243],\n", - " [ 0.0120, 0.0023, -0.0180, ..., -0.0003, -0.0313, 0.0163],\n", - " ...,\n", - " [-0.0160, 0.0207, 0.0082, ..., 0.0153, 0.0131, 0.0034],\n", - " [-0.0073, 0.0424, 0.0274, ..., -0.0075, -0.0554, -0.0114],\n", - " [-0.0192, 0.0268, 0.0036, ..., 0.0094, 0.0045, 0.0030]])),\n", - " ('model.layers.27.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.27.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.norm.weight', tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('lm_head.weight',\n", - " tensor([[-0.0141, -0.0445, 0.0071, ..., -0.0143, -0.0239, -0.0512],\n", - " [ 0.0295, -0.0317, -0.0201, ..., -0.0082, 0.0231, -0.0030],\n", - " [-0.0255, -0.0139, 0.0020, ..., -0.0040, -0.0154, 0.0336],\n", - " ...,\n", - " [ 0.0095, 0.0361, 0.0135, ..., -0.0018, 0.0074, -0.0311],\n", - " [-0.0092, 0.0060, 0.0594, ..., -0.0046, 0.0117, 0.0364],\n", - " [ 0.0228, -0.0265, -0.0262, ..., 0.0038, 0.0097, -0.0257]]))])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.state_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "N params SSM: 5.305533088\n" - ] - } - ], - "source": [ - "print(\"N params SSM:\", sum(p.numel() for p in apriel_ssm.parameters() if p.requires_grad)/1e9)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load State dict into SSM" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "apriel_ssm.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.weight', 'model.layers.7.mixer.conv1d.bias', 'model.layers.7.mixer.out_proj.weight', 'model.layers.8.mixer.z_bias', 'model.layers.8.mixer.D', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.out_proj.weight', 'model.layers.9.mixer.z_bias', 'model.layers.9.mixer.D', 'model.layers.9.mixer.in_proj.weight', 'model.layers.9.mixer.conv1d.weight', 'model.layers.9.mixer.conv1d.bias', 'model.layers.9.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.19.mixer.z_bias', 'model.layers.19.mixer.D', 'model.layers.19.mixer.in_proj.weight', 'model.layers.19.mixer.conv1d.weight', 'model.layers.19.mixer.conv1d.bias', 'model.layers.19.mixer.out_proj.weight', 'model.layers.20.mixer.z_bias', 'model.layers.20.mixer.D', 'model.layers.20.mixer.in_proj.weight', 'model.layers.20.mixer.conv1d.weight', 'model.layers.20.mixer.conv1d.bias', 'model.layers.20.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 'model.layers.21.mixer.out_proj.weight', 'model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight', 'model.layers.23.mixer.z_bias', 'model.layers.23.mixer.D', 'model.layers.23.mixer.in_proj.weight', 'model.layers.23.mixer.conv1d.weight', 'model.layers.23.mixer.conv1d.bias', 'model.layers.23.mixer.out_proj.weight', 'model.layers.24.mixer.z_bias', 'model.layers.24.mixer.D', 'model.layers.24.mixer.in_proj.weight', 'model.layers.24.mixer.conv1d.weight', 'model.layers.24.mixer.conv1d.bias', 'model.layers.24.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight', 'model.layers.26.mixer.z_bias', 'model.layers.26.mixer.D', 'model.layers.26.mixer.in_proj.weight', 'model.layers.26.mixer.conv1d.weight', 'model.layers.26.mixer.conv1d.bias', 'model.layers.26.mixer.out_proj.weight', 'model.layers.27.mixer.z_bias', 'model.layers.27.mixer.D', 'model.layers.27.mixer.in_proj.weight', 'model.layers.27.mixer.conv1d.weight', 'model.layers.27.mixer.conv1d.bias', 'model.layers.27.mixer.out_proj.weight'], unexpected_keys=['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight'])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.load_state_dict(apriel_state_dict, strict=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "apriel_ssm.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Save checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'apriel_ssm' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mapriel_ssm\u001b[49m\u001b[38;5;241m.\u001b[39msave_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints/ssm/apriel_ssm_instruct_base\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 2\u001b[0m save_config\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", - "\u001b[0;31mNameError\u001b[0m: name 'apriel_ssm' is not defined" - ] - } - ], - "source": [ - "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_base\",\n", - " save_config=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "24" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.model.layers[0].mixer.n_v_heads" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Try a forward pass" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", - "batch_size = 1\n", - "max_length = 128\n", - "state = SimpleNamespace()\n", - "state.key_value_memory_dict = apriel_ssm.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)\n", - "state.batch_size = batch_size\n", - "state.seqlen_offset = 0\n", - "static_inputs = {\"inference_params\": state,\n", - " \"input_ids\": input_ids,\n", - " \"use_cache\": True,\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-3.0781, 2.3594, 1.4609, ..., -2.3438, -1.9688, 0.6484],\n", - " [-5.8125, 4.9688, 0.4414, ..., -4.2500, -3.5156, -4.8125],\n", - " [-5.5000, 3.3594, 1.1484, ..., -3.4375, -2.3125, -4.4375],\n", - " ...,\n", - " [-2.2812, 0.1465, 2.2344, ..., -7.6875, -3.0312, -6.2500],\n", - " [-6.8750, 1.7812, -1.3750, ..., -7.4688, -5.6875, -4.4062],\n", - " [-2.0156, 2.0938, 3.1094, ..., -3.0156, -2.1406, -2.2812]]],\n", - " device='cuda:0', grad_fn=), all_hidden_states=(), last_hidden_state=tensor([[[-1.3828, 0.0625, -2.7500, ..., -0.6523, -0.8906, 1.4609],\n", - " [ 2.1406, -0.0247, -3.0156, ..., -0.0074, 1.0234, 1.3828],\n", - " [ 1.6016, -0.7266, -1.2422, ..., -0.4004, -0.8242, -0.5586],\n", - " ...,\n", - " [ 1.5234, -0.0262, -1.5469, ..., -0.4922, -1.0078, 1.2344],\n", - " [-0.4629, -0.6055, -1.3906, ..., -0.9922, -0.3066, 1.1875],\n", - " [-0.7539, -0.0243, -2.4688, ..., -1.0625, -2.7188, 2.6875]]],\n", - " device='cuda:0', dtype=torch.bfloat16, grad_fn=))" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.forward(**static_inputs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load Apriel SSM into HF class" - ] - }, - { - "cell_type": "code", - "execution_count": 130, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], - "source": [ - "import torch\n", - "from mamba_ssm import MambaLMHeadModel\n", - "from mamba_ssm.models.config_mamba import MambaConfig\n", - "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", - "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", - "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", - "from transformers.cache_utils import StaticCache\n", - "from types import SimpleNamespace\n", - "import os\n", - "import shutil\n", - "# make sure the code changes reflected without reload\n", - "%load_ext autoreload\n", - "%autoreload 2\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "model_path = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/apriel_ssminstr-distil-randinit-bs768-lr0.0003-sl4096_ti5000_luke_mix1/export/apriel_ssm/5000\"\n", - "modeling_path = \"/home/toolkit/dev/Fast-LLM/fast_llm/models/ssm/external\"\n", - "# # copy the config.json to the model path\n", - "shutil.copy(os.path.join(modeling_path, \"modeling_ssm_apriel.py\"), os.path.join(model_path, \"modeling_ssm_apriel.py\"))\n", - "shutil.copy(os.path.join(modeling_path, \"configuration_ssm_apriel.py\"), os.path.join(model_path, \"configuration_ssm_apriel.py\"))\n", - "\n", - "tokenizer_path = \"/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/\"\n", - "# # cp tokenizer*\n", - "# shutil.copy(os.path.join(tokenizer_path, \"tokenizer.json\"), os.path.join(model_path, \"tokenizer.json\"))\n", - "# shutil.copy(os.path.join(tokenizer_path, \"tokenizer_config.json\"), os.path.join(model_path, \"tokenizer_config.json\"))\n", - "# shutil.copy(os.path.join(tokenizer_path, \"special_tokens_map.json\"), os.path.join(model_path, \"special_tokens_map.json\"))\n", - "# shutil.copy(os.path.join(tokenizer_path, \"vocab.json\"), os.path.join(model_path, \"vocab.json\"))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", - "Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.08s/it]\n" - ] - } - ], - "source": [ - "\n", - "apriel_ssm = AprielSSMForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device=\"cuda\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "config = apriel_ssm.config" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Mamba in Llama: SSM hybrid " - ] - }, - { - "cell_type": "code", - "execution_count": 90, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], - "source": [ - "\n", - "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", - "import torch\n", - "from mamba_ssm import MambaLMHeadModel\n", - "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", - "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", - "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", - "from transformers.cache_utils import StaticCache\n", - "from types import SimpleNamespace\n", - "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig\n", - "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridModel, AprielSSMDecoderLayer\n", - "# from fast_llm.models.ssm.external.__hybrid_wrapper import MambaTransformerHybridModelWrapper\n", - "# make sure the code changes reflected without reload\n", - "%load_ext autoreload\n", - "%autoreload 2\n" - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", - "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", - "\n", - "# d_xb = config.num_key_value_heads * config.head_dim\n", - "d_inner = config.num_attention_heads * config.head_dim\n", - "d_state = config.head_dim\n", - "hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),\n", - " ssm_block_pattern=[\"m2d\", \"t\"] * 14,\n", - " ssm_cfg={\n", - " \"d_state\": 64,\n", - " \"n_v_heads\": 24,\n", - " \"n_qk_heads\": 24,\n", - " # \"d_xb\": d_xb,\n", - " \"expand\": 1,\n", - " \"chunk_size\": 128,\n", - " \"activation\": \"identity\",\n", - " \"bias\": False,\n", - " \"d_inner\": 24 * 128, # num_heads * head_dim\n", - " })\n", - "# hybrdif_apriel_config" - ] - }, - { - "cell_type": "code", - "execution_count": 87, - "metadata": {}, - "outputs": [], - "source": [ - "hybrid_apriel_model = AprielSSMHybridModel(hybrdif_apriel_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 88, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - ")" - ] - }, - "execution_count": 88, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "hybrid_apriel_model.layers[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 91, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 91, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "isinstance(hybrid_apriel_model.layers[0], AprielSSMDecoderLayer)" - ] - }, - { - "cell_type": "code", - "execution_count": 84, - "metadata": {}, - "outputs": [], - "source": [ - "device = \"cpu\" #if torch.cuda.is_available() else \"cpu\"\n", - "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", - "batch_size = 1\n", - "max_length = 128\n", - "state = SimpleNamespace()\n", - "state.key_value_memory_dict = hybrid_apriel_model.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)\n", - "state.batch_size = batch_size\n", - "state.seqlen_offset = 0\n", - "static_inputs = {\"inference_params\": state,\n", - " \"input_ids\": input_ids,\n", - " \"use_cache\": True,\n", - "}\n" - ] - }, - { - "cell_type": "code", - "execution_count": 73, - "metadata": {}, - "outputs": [ - { - "ename": "OutOfMemoryError", - "evalue": "CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 79.10 GiB of which 1.72 GiB is free. Process 191417 has 19.83 GiB memory in use. Process 1524280 has 57.54 GiB memory in use. Of the allocated memory 18.11 GiB is allocated by PyTorch, and 1.05 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[73], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mhybrid_apriel_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mbfloat16)\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3110\u001b[0m, in \u001b[0;36mPreTrainedModel.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3105\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype_present_in_args:\n\u001b[1;32m 3106\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 3107\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3108\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m `dtype` by passing the correct `torch_dtype` argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3109\u001b[0m )\n\u001b[0;32m-> 3110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1174\u001b[0m, in \u001b[0;36mModule.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1171\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1172\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[0;32m-> 1174\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconvert\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:780\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 780\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 784\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:805\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 801\u001b[0m \u001b[38;5;66;03m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[1;32m 802\u001b[0m \u001b[38;5;66;03m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[1;32m 803\u001b[0m \u001b[38;5;66;03m# `with torch.no_grad():`\u001b[39;00m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 805\u001b[0m param_applied \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 806\u001b[0m p_should_use_set_data \u001b[38;5;241m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[1;32m 808\u001b[0m \u001b[38;5;66;03m# subclasses may have multiple child tensors so we need to use swap_tensors\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1160\u001b[0m, in \u001b[0;36mModule.to..convert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 1153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m convert_to_format \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m t\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m5\u001b[39m):\n\u001b[1;32m 1154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(\n\u001b[1;32m 1155\u001b[0m device,\n\u001b[1;32m 1156\u001b[0m dtype \u001b[38;5;28;01mif\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1157\u001b[0m non_blocking,\n\u001b[1;32m 1158\u001b[0m memory_format\u001b[38;5;241m=\u001b[39mconvert_to_format,\n\u001b[1;32m 1159\u001b[0m )\n\u001b[0;32m-> 1160\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1161\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1162\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_floating_point\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_complex\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1163\u001b[0m \u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1164\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1165\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1166\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(e) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot copy out of meta tensor; no data!\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", - "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 79.10 GiB of which 1.72 GiB is free. Process 191417 has 19.83 GiB memory in use. Process 1524280 has 57.54 GiB memory in use. Of the allocated memory 18.11 GiB is allocated by PyTorch, and 1.05 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)" - ] - } - ], - "source": [ - "hybrid_apriel_model.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 79, - "metadata": {}, - "outputs": [ - { - "ename": "RuntimeError", - "evalue": "split_with_sizes expects split_sizes to sum exactly to 8216 (input tensor's size at dimension -1), but got split_sizes=[6144, 3072, 24]", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[79], line 2\u001b[0m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mhybrid_apriel_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mstatic_inputs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py:1043\u001b[0m, in \u001b[0;36mAprielSSMHybridModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, inference_params, **flash_attn_kwargs)\u001b[0m\n\u001b[1;32m 1041\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_hidden_states:\n\u001b[1;32m 1042\u001b[0m all_hidden_states \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (hidden_states,)\n\u001b[0;32m-> 1043\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1044\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1045\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1046\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1047\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1048\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1049\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1051\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1052\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1053\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1054\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1056\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1058\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(decoder_layer, AprielDecoderLayer):\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py:805\u001b[0m, in \u001b[0;36mAprielSSMDecoderLayer.forward\u001b[0;34m(self, hidden_states, inference_params, **kwargs)\u001b[0m\n\u001b[1;32m 801\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 803\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[0;32m--> 805\u001b[0m mixer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmixer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 806\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 807\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 808\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 810\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m mixer_outputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhidden_states\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mto(residual\u001b[38;5;241m.\u001b[39mdtype) \u001b[38;5;241m+\u001b[39m residual\n\u001b[1;32m 812\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py:460\u001b[0m, in \u001b[0;36mDiscreteMamba2.forward\u001b[0;34m(self, u, return_mixer_matrix, inference_params, **kwargs)\u001b[0m\n\u001b[1;32m 458\u001b[0m \u001b[38;5;66;03m# Project input\u001b[39;00m\n\u001b[1;32m 459\u001b[0m xBCzA_log \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj(u)\n\u001b[0;32m--> 460\u001b[0m xBC, z, A_log \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 461\u001b[0m \u001b[43m \u001b[49m\u001b[43mxBCzA_log\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 462\u001b[0m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\n\u001b[1;32m 463\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43md_inner\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_qk_heads\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43md_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 464\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43md_inner\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 465\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_v_heads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 466\u001b[0m \u001b[43m \u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 467\u001b[0m \u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 468\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 470\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m state \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 471\u001b[0m \u001b[38;5;66;03m# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv\u001b[39;00m\n\u001b[1;32m 472\u001b[0m \u001b[38;5;66;03m# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.\u001b[39;00m\n\u001b[1;32m 473\u001b[0m xBC_t \u001b[38;5;241m=\u001b[39m rearrange(xBC[:, :seqlen, :], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb l d -> b d l\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/functional.py:196\u001b[0m, in \u001b[0;36msplit\u001b[0;34m(tensor, split_size_or_sections, dim)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 191\u001b[0m split, (tensor,), tensor, split_size_or_sections, dim\u001b[38;5;241m=\u001b[39mdim)\n\u001b[1;32m 192\u001b[0m \u001b[38;5;66;03m# Overwriting reason:\u001b[39;00m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;66;03m# This dispatches to two ATen functions depending on the type of\u001b[39;00m\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# split_size_or_sections. The branching code is in _tensor.py, which we\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;66;03m# call here.\u001b[39;00m\n\u001b[0;32m--> 196\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtensor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\u001b[43msplit_size_or_sections\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/_tensor.py:917\u001b[0m, in \u001b[0;36mTensor.split\u001b[0;34m(self, split_size, dim)\u001b[0m\n\u001b[1;32m 915\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_VF\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;28mself\u001b[39m, split_size, dim) \u001b[38;5;66;03m# type: ignore[attr-defined]\u001b[39;00m\n\u001b[1;32m 916\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 917\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit_with_sizes\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mRuntimeError\u001b[0m: split_with_sizes expects split_sizes to sum exactly to 8216 (input tensor's size at dimension -1), but got split_sizes=[6144, 3072, 24]" - ] - } - ], - "source": [ - "\n", - "hybrid_apriel_model.forward(**static_inputs)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 2.44it/s]\n" - ] - }, - { - "data": { - "text/plain": [ - "AprielForCausalLM(\n", - " (model): AprielModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (rotary_emb): AprielRotaryEmbedding()\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", - "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", - "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", - "apriel_state_dict = apriel_model.state_dict()\n", - "apriel_model.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 129, - "metadata": {}, - "outputs": [], - "source": [ - "# Innitialization using k, q, v from Apriel transformer\n", - "def expand_k_q(k):\n", - " Hq = config.num_attention_heads\n", - " Hk = config.num_key_value_heads\n", - " d_head = config.head_dim\n", - " d = k.shape[-1]\n", - " \n", - " # Expand k\n", - " repeat_factor = Hq // Hk\n", - " k_expanded = k.view(Hk, d_head, d)\n", - " k_expanded = k_expanded.repeat_interleave(repeat_factor, dim=0)\n", - " k_expanded = k_expanded.view(d_head * Hq, d)\n", - " return k_expanded\n", - "\n", - "for block_h, block_t in zip(hybrid_apriel_model.layers, apriel_model.model.layers):\n", - " # print(isinstance(block_h, AprielSSMDecoderLayer))\n", - " if isinstance(block_h, AprielSSMDecoderLayer):\n", - " # print(block_h.mixer.n_v_heads)\n", - " # print(block_t.self_attn.v_proj.weight.shape)\n", - " # print(block_h.mixer.in_proj.weight.shape)\n", - "\n", - " # print(block_h.mixer.in_proj.weight.shape)\n", - " # print(block_t.self_attn.v_proj.weight.shape)\n", - " block_h.mlp.load_state_dict(block_t.mlp.state_dict())\n", - " block_h.input_layernorm.load_state_dict(block_t.input_layernorm.state_dict())\n", - " block_h.post_attention_layernorm.load_state_dict(block_t.post_attention_layernorm.state_dict())\n", - " block_h.mixer.out_proj.load_state_dict(block_t.self_attn.o_proj.state_dict())\n", - " # [x B C z A_log]\n", - " # print(block_h.mixer.d_inner)\n", - " # init x, but interleave to address GQA\n", - " v_expended = expand_k_q(block_t.self_attn.v_proj.weight.data)\n", - " block_h.mixer.in_proj.weight.data[:block_h.mixer.d_inner, : ].copy_(v_expended)\n", - " # init k, but interleave to address GQA\n", - " k_expended = expand_k_q(block_t.self_attn.k_proj.weight.data)\n", - " block_h.mixer.in_proj.weight.data[block_h.mixer.d_inner: 2*block_h.mixer.d_inner, : ].copy_(k_expended)\n", - " # init C ewith Q\n", - " block_h.mixer.in_proj.weight.data[2*block_h.mixer.d_inner: 3*block_h.mixer.d_inner, : ].copy_(block_t.self_attn.q_proj.weight.data)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 124, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1024, 4096])" - ] - }, - "execution_count": 124, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "block_t.self_attn.v_proj.weight.data.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "d_xb = config.num_key_value_heads * config.head_dim\n", - "ssm_layers = [2,4,8]\n", - "attn_layers = [i for i in range(config.num_hidden_layers) if i not in ssm_layers]\n", - "model_name = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", - "ngroups = config.num_attention_heads # n heads\n", - "d_inner = config.head_dim * config.num_attention_heads\n", - "headdim = 128 # d_state\n", - "d_state = config.head_dim\n", - "d_model = config.hidden_size \n", - "assert d_inner == ngroups * d_state\n", - "\n", - "mamba_config = AprielSSMConfig(\n", - " ssm_cfg={\n", - " \"d_state\": 64,\n", - " \"n_v_heads\": 24,\n", - " \"n_qk_heads\": 24,\n", - " \"expand\": 1,\n", - " \"chunk_size\": 128,\n", - " \"activation\": \"identity\",\n", - " \"bias\": False,\n", - " \"d_inner\": 24 * headdim, # num_heads * head_dim\n", - " },\n", - " vocab_size=config.vocab_size, \n", - " hidden_size=config.hidden_size,\n", - " intermediate_size=config.intermediate_size,\n", - " num_hidden_layers=config.num_hidden_layers,\n", - " hidden_act=config.hidden_act,\n", - " initializer_range=config.initializer_range,\n", - " use_cache=config.use_cache,\n", - " mlp_bias=config.mlp_bias,\n", - " tie_word_embeddings=config.tie_word_embeddings,\n", - " pad_token_id=config.pad_token_id,\n", - " bos_token_id=config.bos_token_id,\n", - " eos_token_id=config.eos_token_id,\n", - " head_dim=config.head_dim,\n", - " rms_norm_eps=config.rms_norm_eps\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "student_model = MambaTransformerHybridModelWrapper.init_distillation(None, model_name, \n", - " mamba_config, \n", - " attn_layers=attn_layers, \n", - " init_with_kqvo=True, \n", - " attn_implementation=\"flash_attention_2\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "hymba2", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index 94537c331..38ad5edfe 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -25,13 +25,13 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig + from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig self._config = AprielSSMConfig.from_pretrained(pretrained) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM + from fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM self._model = AprielSSMForCausalLM.from_pretrained( pretrained, @@ -57,3 +57,56 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs): max_length=max_length, **generation_kwargs, ) + + +@register_model("apriel_hybrid_ssm") +class AprielHybridSSMWrapper(HFLM): + """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + # rene currently only supports causal models + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import AprielSSMHybridConfig + + self._config = AprielSSMHybridConfig.from_pretrained(pretrained) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + + self._model = AprielSSMHybridForCausalLM.from_pretrained( + pretrained, + device=self._device, + dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + trust_remote_code=True, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + """Generate text from the model.""" + for key in ("do_sample", "attention_mask"): + if key in generation_kwargs: + generation_kwargs.pop(key) + + # The custom GenerationMixin imported from mamba_ssm currently does not support + # passing stopping criteria. + # For the time being, we simply generate to max length, then truncate (equivalent result). + # This should be revisited to speed up generation + # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) + + return self.model.generate( + input_ids=context, + max_length=max_length, + **generation_kwargs, + ) diff --git a/fast_llm/models/ssm/external/configuration_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py similarity index 100% rename from fast_llm/models/ssm/external/configuration_mtp_llamba.py rename to fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py diff --git a/fast_llm/models/ssm/external/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py similarity index 100% rename from fast_llm/models/ssm/external/modeling_mtp_llamba.py rename to fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 5863f9030..bb6d54579 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -15,7 +15,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import AprielSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat try: from fast_llm.layers.ssm.config import SSMConfig @@ -112,24 +112,102 @@ def get_hf_llamba_out(input_ids, path, format): return output, parameter_sum +# @pytest.mark.slow +# @pytest.mark.skipif( +# not run_test or LMHeadModel is None, +# reason=f"Skipping because one of the following: cartesia_pytorch.Llamba not installed or no CUDA available or Mamba not installed", +# ) +# def test_load_from_llamba_checkpoint(distributed_config): +# """ +# Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. +# """ +# vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json +# batch_size = 2 +# seq_length = 32 + +# path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") +# format = LLambaHuggingfaceCheckpointFormat + +# x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") +# hf_logits, parameter_sum_hf = get_hf_llamba_out(x, path, format) +# hf_logits = hf_logits["logits"].cpu() + +# # Create checkpoint load config +# checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) +# # Initialize model +# model = HybridSSMModel.from_pretrained(checkpoint_config) +# param_sum = 0 +# for stage in model.stages: +# for fsdp in stage.fsdps: +# if hasattr(fsdp, "_weight_shard"): +# param_sum += torch.sum(fsdp._weight_shard).item() +# assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 + +# # model = GPTModel.from_pretrained(checkpoint_config) +# assert model.config.base_model.vocab_size == vocab_size +# schedule_config = ScheduleConfig() +# with NoAutoValidate(): +# batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) +# batch_config.setup(distributed_config) +# batch_config.validate() +# schedule_runner = ScheduleRunner( +# config=schedule_config, +# multi_stage=model, +# distributed_config=model.distributed.config, +# ) +# schedule = Schedule( +# multi_stage=model, +# batch_config=batch_config, +# schedule_config=schedule_config, +# distributed_config=model.distributed.config, +# phase=PhaseType.inference, +# ) +# schedule_runner.setup(model.distributed, optimizer=None) + +# common_kwargs = { +# TransformerKwargs.sequence_first: True, +# TransformerKwargs.grad_output: False, +# } +# input_data = [(x, common_kwargs)] + +# losses, success, metrics = schedule_runner.run_step( +# iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True +# ) + +# logits = input_data[0][1]["logits"].cpu() +# assert torch.allclose(logits, hf_logits, atol=1e-2) + + +def get_hf_apriel_hybrid_out(input_ids, path, format): + from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + + model = AprielSSMHybridForCausalLM.from_pretrained(path, strict=True).to("cuda") + parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) + print(f"Parameter sum: {parameter_sum}") + output = model(input_ids) + del model + torch.cuda.empty_cache() + return output, parameter_sum + + @pytest.mark.slow @pytest.mark.skipif( - not run_test or LMHeadModel is None, - reason=f"Skipping because one of the following: cartesia_pytorch.Llamba not installed or no CUDA available or Mamba not installed", + not run_test, + reason=f"Skipping because no CUDA available or Mamba not installed", ) -def test_load_from_llamba_checkpoint(distributed_config): +def test_load_from_hybridssm_checkpoint(distributed_config): """ Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. """ - vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json + vocab_size = 131072 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json batch_size = 2 seq_length = 32 - path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") - format = LLambaHuggingfaceCheckpointFormat + path = pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug") + format = AprielSSMHHybridHuggingfaceCheckpointFormat x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - hf_logits, parameter_sum_hf = get_hf_llamba_out(x, path, format) + hf_logits, parameter_sum_hf = get_hf_apriel_hybrid_out(x, path, format) hf_logits = hf_logits["logits"].cpu() # Create checkpoint load config @@ -163,15 +241,21 @@ def test_load_from_llamba_checkpoint(distributed_config): phase=PhaseType.inference, ) schedule_runner.setup(model.distributed, optimizer=None) + from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType + from fast_llm.layers.transformer.preprocessing import get_rotary_frequencies + + rotary_config = RotaryConfig(type=RotaryEmbeddingType.default, theta=10000.0) # or whatever type your model uses + frequencies = get_rotary_frequencies(rotary_config, seq_length, 4096, device="cuda") - common_kwargs = { - TransformerKwargs.sequence_first: True, - TransformerKwargs.grad_output: False, - } - input_data = [(x, common_kwargs)] + from types import SimpleNamespace + batch = SimpleNamespace( + token_ids=x, + sequence_lengths=[[seq_length, seq_length]], + ) + input_data = [(batch)] losses, success, metrics = schedule_runner.run_step( - iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True + iter(input_data), schedule, iteration=0, return_metrics=True, preprocessed=False ) logits = input_data[0][1]["logits"].cpu() @@ -183,7 +267,7 @@ def test_load_from_llamba_checkpoint(distributed_config): "hybrid_block_layout,LAYER_CLS", [ (["m", "t"], MambaLayer), - (["m2", "t"], DiscreteMamba2), + (["m2d", "t"], DiscreteMamba2), ], ids=["mamba", "descrete_mamba2"], ) @@ -251,7 +335,7 @@ def test_mamba_block(distributed_config, distributed): "hybrid_block_layout", [ (["m", "t"]), - (["m2", "t"]), + (["m2d", "t"]), ], ids=["mamba", "descrete_mamba2"], ) @@ -338,3 +422,6 @@ def test_hybrid_model_train_with_fast_mode(distributed_config, hybrid_block_layo # }, # losses=losses, # ) + +if __name__ == "__main__": + pytest.main(["-s", __file__]) From 30ad8b8f890e43df61bc733fb6c04da8f9d59889 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 7 May 2025 12:32:06 +0000 Subject: [PATCH 069/122] wip --- .../ssm/external/aperiel_ssm/configuration_ssm_apriel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py index c3f7ef38d..6943a3124 100644 --- a/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py +++ b/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py @@ -96,8 +96,8 @@ def __init__( "bias": False, "d_inner": 24 * self.head_dim, # num_heads * head_dim } - if self.head_dim == self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"]: - logger.warning("Head dim is equal to d_inner // n_qk_heads.") + if self.head_dim != self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"]: + logger.warning("Head dim is not equal to d_inner // n_qk_heads.") __all__ = ["AprielConfig"] From 9c4f38f92c26c4a3a44ab67795f9dd3b58840245 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 7 May 2025 15:39:51 +0000 Subject: [PATCH 070/122] layer-lr scale for mlp as well --- fast_llm/layers/transformer/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 9e1e0bcfa..982381720 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -46,7 +46,7 @@ def __init__( self._create_mixer() self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp" + self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index ) # PEFT. From 1784dcaf64fd312734042176dea827382214a16c Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 7 May 2025 21:55:30 +0000 Subject: [PATCH 071/122] wip --- fast_llm/models/ssm/config.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 7e0a69fd8..15a75ac2e 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -100,19 +100,19 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] - len_block_layout = len(self.hybrid_block_layout) - if len_block_layout != self.transformer.num_layers: - if self.transformer.num_layers % len_block_layout != 0: + + if len(self.hybrid_block_layout) != self.transformer.num_layers: + if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: raise ValueError( - f"hybrid_block_layout length {len_block_layout} does not match num_layers {self.transformer.num_layers}" + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" ) - num_repeats = int(self.transformer.num_layers // len_block_layout) + num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) logger.warning( - f"hybrid_block_layout length {len_block_layout} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" ) self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len_block_layout, self.transformer.num_layers) + Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) Assert.custom( lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", From 1e3cc2847887070c90d366ccf8497babd3feb661 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 7 May 2025 22:01:11 +0000 Subject: [PATCH 072/122] nvm --- fast_llm/models/ssm/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 39aaa6e9f..357a26c04 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -141,7 +141,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() + converters = super()._create_weight_converters() or [] num_layers = self._model.config.base_model.transformer.num_layers ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear From 2dc945b9402dbbe4379458f8238e5ab2fb2b4cff Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 9 May 2025 12:48:08 +0000 Subject: [PATCH 073/122] hybrid modeling --- .../configuration_ssm_hybrid_apriel.py | 10 +- .../modeling_ssm_hybrid_apriel.py | 288 ++++++++++++++---- .../ssm/external/eval/apriel_eval_wrapper.py | 30 +- 3 files changed, 250 insertions(+), 78 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py index b030150ce..1d230bb67 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py @@ -431,7 +431,7 @@ def __init__( **kwargs, ) - self.ssm_cfg = ssm_cfg or { + ssm_defaults = { "d_state": 64, "n_v_heads": 24, "n_qk_heads": 24, @@ -439,8 +439,10 @@ def __init__( "chunk_size": 128, "activation": "identity", "bias": False, + "d_conv": 4, "d_inner": 24 * self.head_dim, # num_heads * head_dim } - - -__all__ = ["AprielConfig"] + self.ssm_cfg = ssm_cfg or ssm_defaults + for k, v in ssm_defaults.items(): + if k not in self.ssm_cfg: + self.ssm_cfg[k] = v diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index 950327df9..d6fd35185 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -8,7 +8,6 @@ from einops import rearrange, repeat from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined -from mamba_ssm.utils.generation import GenerationMixin from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache @@ -29,14 +28,133 @@ logger = logging.get_logger(__name__) +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + self.conv_states += [ + torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + ] + self.ssm_states += [ + torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + @dataclass -class CustomMambaCausalLMOutput(ModelOutput): +class AprielHybridCausalOutput(ModelOutput): """Custom output class for MambaLMHeadModel.""" loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None last_hidden_state: Optional[torch.FloatTensor] = None + attention_weights: Optional[torch.FloatTensor] = None + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None class AprielRMSNorm(nn.Module): @@ -333,6 +451,7 @@ def materialize_mixer(A_log, B, C, D): return T +# This is from LLmaba/Mohawk: https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py class DiscreteMamba2(nn.Module): def __init__( self, @@ -424,7 +543,14 @@ def d_output(self): def state_to_tensor(self): return self.layer.state_to_tensor - def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs): + def forward( + self, + u, + return_mixer_matrix=False, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): """ u: (B, L, D) Returns: same shape as u @@ -433,16 +559,17 @@ def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs) # assert state is None batch, seqlen, dim = u.shape - state = None - if inference_params is not None: - state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: + ssm_state, conv_state = None, None + if past_key_value is not None: + ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + if cache_position[0] > 0: # States are updated inplace - out, _ = self.step(u, state) + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _ = self.step(u, ssm_state, conv_state) return {"hidden_states": out} # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen + chunk_size = self.chunk_size if ssm_state is None else seqlen # Pad input to nearest multiple of chunklen padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size @@ -460,11 +587,11 @@ def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs) dim=-1, ) - if state is not None: + if ssm_state is not None: # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) # Convolutional layer xBC = self.convolutional_forward(xBC, padded_len) @@ -493,12 +620,12 @@ def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs) C=C, chunk_size=chunk_size, # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), + return_final_states=(ssm_state is not None), ) - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) + if ssm_state is not None: + y, ssm_state_update = result + ssm_state.copy_(ssm_state_update) else: y = result @@ -513,7 +640,7 @@ def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs) outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] return outputs - def step(self, u, state, **kwargs): + def step(self, u, ssm_state, conv_state, **kwargs): """ u: (B D) state: dict of states @@ -521,7 +648,7 @@ def step(self, u, state, **kwargs): """ # Project input - xBCzA_log = self.in_proj(u.squeeze(1)) + xBCzA_log = self.in_proj(u) xBC, z, A_log = torch.split( xBCzA_log, [ @@ -532,8 +659,8 @@ def step(self, u, state, **kwargs): dim=-1, ) - xBC, conv_state = self.convolutional_step(xBC, state["conv"]) - state["conv"].copy_(conv_state) # update state in place + xBC, conv_state = self.convolutional_step(xBC, conv_state) + conv_state.copy_(conv_state) # update state in place x, B, C = torch.split( xBC, @@ -549,7 +676,7 @@ def step(self, u, state, **kwargs): B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) - state["ssm"] = state["ssm"].to(x.dtype) + ssm_state = ssm_state.to(x.dtype) zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) y = selective_state_update( @@ -559,7 +686,7 @@ def step(self, u, state, **kwargs): A=-ones, B=B, C=C, - state=state["ssm"], # will be updated in place + state=ssm_state, # will be updated in place dt_bias=zeros, D=zeros, ) @@ -570,7 +697,7 @@ def step(self, u, state, **kwargs): # Norm and gate out = self.out_proj(y * F.silu(z + self.z_bias)) - return out, state + return out, ssm_state def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): device = self.in_proj.weight.device @@ -602,16 +729,17 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states """ assert self.layer_idx is not None # Allocate memory if not exists - if self.layer_idx not in inference_params.key_value_memory_dict: - inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( - batch_size, inference_params.max_seqlen, dtype=torch.float32 - ) + # if self.layer_idx not in inference_params.ssm_states: + # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + # batch_size, inference_params.max_seqlen, dtype=torch.float32 + # ) # Get states - states = inference_params.key_value_memory_dict[self.layer_idx] + ssm_states = inference_params.ssm_states[self.layer_idx] + conv_states = inference_params.conv_states[self.layer_idx] if initialize_states: - states["conv"].zero_() - states["ssm"].zero_() - return states + ssm_states.zero_() + conv_states.zero_() + return ssm_states, conv_states def convolutional_forward(self, xBC, padded_len): if causal_conv1d_fn is None or self.activation not in [ @@ -724,7 +852,7 @@ def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, d self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) def forward( - self, hidden_states: torch.Tensor, inference_params=None, **kwargs + self, hidden_states: torch.Tensor, **kwargs ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: outputs = {} @@ -734,7 +862,7 @@ def forward( mixer_outputs = self.mixer( hidden_states, - inference_params=inference_params, + **kwargs, ) hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual @@ -878,6 +1006,7 @@ def __init__(self, config: AprielSSMHybridConfig, device=None, dtype=None, **kwa factory_kwargs = {"device": device, "dtype": dtype} self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) blocks = [] + logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") for layer_idx, type in enumerate(config.hybrid_block_layout): if type == "m2d": blocks.append(AprielSSMDecoderLayer(config, layer_idx, **factory_kwargs)) @@ -913,7 +1042,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -943,7 +1072,11 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = DynamicCache() + # past_key_values = HybridMambaAttentionDynamicCache() + logger.warning_once( + "Hybrid Apriel requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1133,12 +1266,11 @@ class AprielSSMHybridForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - def __init__(self, config, device=None, dtype=None, **kwargs): - super().__init__(config, device=device, dtype=dtype, **kwargs) + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) self.model = AprielSSMHybridModel(config) self.vocab_size = config.vocab_size - factory_kwargs = {"device": device, "dtype": dtype} - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() @@ -1161,23 +1293,82 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + # "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + def forward( self, input_ids: torch.LongTensor = None, position_ids=None, return_hidden_states=False, return_logits=True, - inference_params=None, num_last_tokens=0, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[tuple, CausalLMOutputWithPast]: + # past_key_values is None if prepare_inputs_for_generation is not called, which is the case when we evaluate without calling generate (non-generation tasks) + # Its generally ok if cache is nto instantiated in this case, since we do single pass per sample anyways, a warning will be triggered in the model outputs: BaseModelOutputWithPast = self.model( input_ids, return_hidden_states=return_hidden_states, - inference_params=inference_params, position_ids=position_ids, - return_dict=True, + past_key_values=past_key_values, + **kwargs, ) if outputs["last_hidden_state"] is not None and return_logits: @@ -1186,22 +1377,17 @@ def forward( else: outputs["logits"] = None - return CustomMambaCausalLMOutput( + return AprielHybridCausalOutput( loss=None, logits=outputs["logits"], all_hidden_states=outputs.hidden_states, last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, ) - def generate(self, *args, **kwargs): - """ - This is a wrapper to make sure we comply with the HF generation interface for eval harness - """ - return super().generate(*args, **kwargs) - __all__ = [ - "AprielSSMForCausalLM", - "AprielModel", + "AprielSSMHybridForCausalLM", + "AprielSSMHybridModel", "AprielSSMPreTrainedModel", ] diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index 38ad5edfe..a7cf34e4c 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -8,11 +8,10 @@ @register_model("apriel_ssm") class AprielSSMWrapper(HFLM): - """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" + """Wrapper for AprielSSM model for compatibility with lm-evaluation-harness.""" def __init__(self, pretrained, **kwargs) -> None: if "backend" in kwargs: - # rene currently only supports causal models assert kwargs["backend"] == "causal" super().__init__( @@ -61,11 +60,10 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs): @register_model("apriel_hybrid_ssm") class AprielHybridSSMWrapper(HFLM): - """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" + """Wrapper for AprielHybridSSM model for compatibility with lm-evaluation-harness.""" def __init__(self, pretrained, **kwargs) -> None: if "backend" in kwargs: - # rene currently only supports causal models assert kwargs["backend"] == "causal" super().__init__( @@ -80,7 +78,7 @@ def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import AprielSSMHybridConfig - self._config = AprielSSMHybridConfig.from_pretrained(pretrained) + self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" @@ -89,24 +87,10 @@ def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype] self._model = AprielSSMHybridForCausalLM.from_pretrained( pretrained, device=self._device, - dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), - trust_remote_code=True, + torch_dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + **kwargs, ) def _model_generate(self, context, max_length, stop, **generation_kwargs): - """Generate text from the model.""" - for key in ("do_sample", "attention_mask"): - if key in generation_kwargs: - generation_kwargs.pop(key) - - # The custom GenerationMixin imported from mamba_ssm currently does not support - # passing stopping criteria. - # For the time being, we simply generate to max length, then truncate (equivalent result). - # This should be revisited to speed up generation - # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) - - return self.model.generate( - input_ids=context, - max_length=max_length, - **generation_kwargs, - ) + # FOR now evaluating with non-generation tasks + raise NotImplementedError("Generation not implemented yet for AprielHybridSSMWrapper") From 4277e6779ab78d8dc62615d729e54c685cfa5d92 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 9 May 2025 12:49:57 +0000 Subject: [PATCH 074/122] modeling --- .../modeling_ssm_hybrid_apriel.py | 75 +++++++++---------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index d6fd35185..95c09e0c2 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -699,28 +699,28 @@ def step(self, u, ssm_state, conv_state, **kwargs): return out, ssm_state - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - device = self.in_proj.weight.device - # conv_state: - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, - self.d_conv, - self.conv1d.weight.shape[0], - device=device, - dtype=conv_dtype, - ).transpose(1, 2) - # ssm_state: - ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype - ssm_state = torch.zeros( - batch_size, - self.n_v_heads, - self.headdim, - self.d_state, - device=device, - dtype=ssm_dtype, - ) - return {"conv": conv_state, "ssm": ssm_state} + # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # device = self.in_proj.weight.device + # # conv_state: + # conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + # conv_state = torch.zeros( + # batch_size, + # self.d_conv, + # self.conv1d.weight.shape[0], + # device=device, + # dtype=conv_dtype, + # ).transpose(1, 2) + # # ssm_state: + # ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + # ssm_state = torch.zeros( + # batch_size, + # self.n_v_heads, + # self.headdim, + # self.d_state, + # device=device, + # dtype=ssm_dtype, + # ) + # return {"conv": conv_state, "ssm": ssm_state} def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): """ @@ -800,7 +800,6 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - inference_params=None, # just to be compatible with SSM block **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -878,11 +877,11 @@ def forward( return outputs - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - """Allocate inference cache for the model.""" - if getattr(self.mixer, "allocate_inference_cache", None) is None: - return - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # """Allocate inference cache for the model.""" + # if getattr(self.mixer, "allocate_inference_cache", None) is None: + # return + # return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) APRIEL_START_DOCSTRING = r""" @@ -920,9 +919,9 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def allocate_inference_cache(self, *args, **kwargs): - """Allocate inference cache for the model.""" - return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) + # def allocate_inference_cache(self, *args, **kwargs): + # """Allocate inference cache for the model.""" + # return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) APRIEL_INPUTS_DOCSTRING = r""" @@ -1028,13 +1027,13 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - def allocate_inference_cache(self, *args, **kwargs): - """Allocate inference cache for the model.""" - cache = {} - for i, layer in enumerate(self.layers): - if isinstance(layer, AprielSSMDecoderLayer): - cache[i] = layer.allocate_inference_cache(*args, **kwargs) - return cache + # def allocate_inference_cache(self, *args, **kwargs): + # """Allocate inference cache for the model.""" + # cache = {} + # for i, layer in enumerate(self.layers): + # if isinstance(layer, AprielSSMDecoderLayer): + # cache[i] = layer.allocate_inference_cache(*args, **kwargs) + # return cache @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) def forward( From c71cb16db8df7d637f83853ccd9419162c151177 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 9 May 2025 13:01:10 +0000 Subject: [PATCH 075/122] nvm --- tests/test_ssms.py | 45 +++------------------------------------------ 1 file changed, 3 insertions(+), 42 deletions(-) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 551da7e10..9f3382b6c 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -13,6 +13,7 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames +from fast_llm.layers.ssm.config import SSMBlockType from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat from fast_llm.models.ssm.config import AprielSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat @@ -182,53 +183,13 @@ def test_load_from_hybridssm_checkpoint(distributed_config): param_sum += torch.sum(fsdp._weight_shard).item() assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 - # # model = GPTModel.from_pretrained(checkpoint_config) - # assert model.config.base_model.vocab_size == vocab_size - # schedule_config = ScheduleConfig() - # with NoAutoValidate(): - # batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) - # batch_config.setup(distributed_config) - # batch_config.validate() - # schedule_runner = ScheduleRunner( - # config=schedule_config, - # multi_stage=model, - # distributed_config=model.distributed.config, - # ) - # schedule = Schedule( - # multi_stage=model, - # batch_config=batch_config, - # schedule_config=schedule_config, - # distributed_config=model.distributed.config, - # phase=PhaseType.inference, - # ) - # schedule_runner.setup(model.distributed, optimizer=None) - # from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType - # from fast_llm.layers.transformer.preprocessing import get_rotary_frequencies - - # rotary_config = RotaryConfig(type=RotaryEmbeddingType.default, theta=10000.0) # or whatever type your model uses - # frequencies = get_rotary_frequencies(rotary_config, seq_length, 4096, device="cuda") - - # from types import SimpleNamespace - - # batch = SimpleNamespace( - # token_ids=x, - # sequence_lengths=[[seq_length, seq_length]], - # ) - # input_data = [batch] - # losses, success, metrics = schedule_runner.run_step( - # iter(input_data), schedule, iteration=0, return_metrics=True, preprocessed=False - # ) - - # logits = input_data[0][1]["logits"].cpu() - # assert torch.allclose(logits, hf_logits, atol=1e-2) - @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( "hybrid_block_layout,LAYER_CLS", [ - (["m", "t"], MambaLayer), - (["m2d", "t"], DiscreteMamba2), + ([SSMBlockType.mamba, SSMBlockType.transformer], MambaLayer), + ([SSMBlockType.mamba2_discrete, SSMBlockType.transformer], DiscreteMamba2), ], ids=["mamba", "discrete_mamba2"], ) From be04c192575be79b3568162d54847c692d4cb56a Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 9 May 2025 13:14:51 +0000 Subject: [PATCH 076/122] output lr scale --- fast_llm/layers/language_model/config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index ef1e3a376..96ea3f7c9 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -146,6 +146,12 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + output_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the output weights.", + doc="May be used to freeze the output weights by setting their scale to zero.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.transformer.validate() From 1311f5b28aafe2c3df4e2deab7b8b350cc4d60d7 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 9 May 2025 13:15:12 +0000 Subject: [PATCH 077/122] output_lr_scale --- fast_llm/layers/language_model/head.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 813dcc076..2cc7730b5 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -104,6 +104,7 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.output_lr_scale, ) def forward( From baf4011943a4912467ab16930336febeb005d6b6 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 9 May 2025 13:44:15 +0000 Subject: [PATCH 078/122] nvm --- .../ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index 95c09e0c2..6800158ee 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -9,6 +9,7 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from torch import nn +from transformers import GenerationMixin from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter From 6cf26c5d78a52dfcd0fddd02b156e9dd53fdfa40 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sat, 10 May 2025 15:41:02 +0000 Subject: [PATCH 079/122] eval --- fast_llm/models/ssm/external/eval/run_lm_eval.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm/models/ssm/external/eval/run_lm_eval.py index af07869a8..c910bcc38 100644 --- a/fast_llm/models/ssm/external/eval/run_lm_eval.py +++ b/fast_llm/models/ssm/external/eval/run_lm_eval.py @@ -1,6 +1,9 @@ from lm_eval.__main__ import cli_evaluate -from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import AprielSSMWrapper # noqa: F401 +from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import ( # noqa: F401 + AprielHybridSSMWrapper, + AprielSSMWrapper, +) if __name__ == "__main__": cli_evaluate() From 901d1b6ad38cd6b498e43ada472074bbeb6a3766 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 12:50:33 +0000 Subject: [PATCH 080/122] rename --- .../{aperiel_ssm => apriel_ssm}/configuration_ssm_apriel.py | 0 .../{aperiel_ssm => apriel_ssm}/modeling_ssm_apriel.py | 2 +- fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) rename fast_llm/models/ssm/external/{aperiel_ssm => apriel_ssm}/configuration_ssm_apriel.py (100%) rename fast_llm/models/ssm/external/{aperiel_ssm => apriel_ssm}/modeling_ssm_apriel.py (99%) diff --git a/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py similarity index 100% rename from fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py rename to fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py diff --git a/fast_llm/models/ssm/external/aperiel_ssm/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py similarity index 99% rename from fast_llm/models/ssm/external/aperiel_ssm/modeling_ssm_apriel.py rename to fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py index dd228024c..a46530fcd 100644 --- a/fast_llm/models/ssm/external/aperiel_ssm/modeling_ssm_apriel.py +++ b/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py @@ -19,7 +19,7 @@ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils.generic import ModelOutput -from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig +from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig logger = logging.get_logger(__name__) diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index a7cf34e4c..02c9176b9 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -24,13 +24,13 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig + from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig self._config = AprielSSMConfig.from_pretrained(pretrained) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM + from fast_llm.models.ssm.external.apriel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM self._model = AprielSSMForCausalLM.from_pretrained( pretrained, From 616c54069802ccf2e65ca872f84e30024bc0ef20 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 13:39:27 +0000 Subject: [PATCH 081/122] per_layer_lr_scale --- fast_llm/layers/common/config.py | 11 +++++++++++ fast_llm/layers/ssm/config.py | 13 +++++++++---- fast_llm/layers/ssm/discrete_mamba2.py | 11 ++++++++++- fast_llm/layers/ssm/mamba_layer.py | 10 ++++++++++ fast_llm/layers/transformer/config.py | 10 ++-------- 5 files changed, 42 insertions(+), 13 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 16be49879..e8e068c0c 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -11,6 +11,17 @@ from fast_llm.layers.common.normalization import LayerNorm, RMSNorm +class LLMBlockConfig(BaseModelConfig): + _abstract = False + + per_layer_lr_scale: list[float] | None = Field( + default=None, + desc="Custom learning rate scale for each layer.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + ) + + class NormalizationImplementation(str, enum.Enum): """ An enum for the available implementations of layer norm. diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 9faec8799..846ab43d0 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,9 +1,8 @@ import enum -from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig from fast_llm.utils import Assert @@ -34,7 +33,7 @@ class SSMBlockType(str, enum.Enum): @config_class() -class SSMConfig(BaseModelConfig): +class SSMConfig(LLMBlockConfig): _abstract = False # Normalization @@ -122,6 +121,12 @@ class SSMConfig(BaseModelConfig): desc="Inner dimension for Mamba2 blocks.", hint=FieldHint.core, ) + mamba_lr_scale: float = Field( + default=None, + desc="Learning rate scale for Mamba blocks.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) def _validate(self) -> None: with self._set_implicit_default(): diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 49dacb914..16686fe8e 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -9,6 +9,7 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.utils import get_lr_scale """ This code is adapted fropm https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py @@ -44,6 +45,8 @@ def __init__( bias = config.add_bias_linear self.layer_idx = layer_idx self._return_input = return_input + layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) @@ -73,6 +76,7 @@ def __init__( (td_inner,), weight_decay=False, init_method=init_zeros_, + lr_scale=mamba_layer_lr_scale, ) if not bias else 0.0 @@ -84,14 +88,18 @@ def __init__( init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 + lr_scale=mamba_layer_lr_scale, + ) + self.conv1d_bias = ParameterMeta.from_dims( + (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale ) - self.conv1d_bias = ParameterMeta.from_dims((td_conv,), init_method=bias_init_method(self.conv1d_weight)) # D "skip" parameter self.D = ParameterMeta.from_dims( (td_n_qk_heads,), weight_decay=False, init_method=init_ones_, + lr_scale=mamba_layer_lr_scale, ) # out_proj @@ -100,6 +108,7 @@ def __init__( td_model, bias=bias, weight_init_method=kaiming_init_(td_inner.size), + lr_scale=mamba_layer_lr_scale, ) @property diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 4704b5228..e44a4e1db 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -9,6 +9,7 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ +from fast_llm.utils import get_lr_scale """ Note: this is mostly addapted from https://github.com/Zyphra/Zamba2, similar code is aslo in https://github.com/state-spaces/mamba. @@ -81,6 +82,8 @@ def __init__( self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size + layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) self.in_proj_weight = ParameterMeta.from_dims( (td_inner_proj, td_model), @@ -90,6 +93,7 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), init_method=kaiming_init_(td_inner.size), + lr_scale=mamba_layer_lr_scale, ) self.conv1d_bias = None @@ -102,6 +106,7 @@ def __init__( td_x_proj, weight_init_method=kaiming_init_(td_inner.size), bias=False, + layer_lr_scale=mamba_layer_lr_scale, **factory_kwargs, ) self.x_proj.weight.auto_grad_accumulation = True @@ -110,6 +115,7 @@ def __init__( self.dt_proj_weight = ParameterMeta.from_dims( (td_inner, tdt_rank), init_method=kaiming_init_(tdt_rank.size), + lr_scale=mamba_layer_lr_scale, ) self.dt_proj_bias = ParameterMeta.from_dims( @@ -117,12 +123,14 @@ def __init__( init_method=init_dtprojbias( self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs ), + lr_scale=mamba_layer_lr_scale, ) self.A_log = ParameterMeta.from_dims( (td_inner, td_state), weight_decay=False, init_method=init_A(self.d_state, self.d_inner), + lr_scale=mamba_layer_lr_scale, ) # D "skip" parameter @@ -130,6 +138,7 @@ def __init__( (td_inner,), weight_decay=False, init_method=init_ones_, + lr_scale=mamba_layer_lr_scale, ) self.out_proj = Linear( @@ -137,6 +146,7 @@ def __init__( td_model, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. weight_init_method=kaiming_init_(td_model.size), + lr_scale=mamba_layer_lr_scale, **factory_kwargs, ) self.out_proj.weight.auto_grad_accumulation = True diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index c6ea98b19..e4eaac1d3 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -11,7 +11,7 @@ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.layers.common.config import NormalizationConfig, PeftConfig, PeftType +from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig, PeftConfig, PeftType from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -245,7 +245,7 @@ def _validate(self) -> None: @config_class() -class TransformerConfig(BaseModelConfig): +class TransformerConfig(LLMBlockConfig): _abstract = False normalization: NormalizationConfig = Field( default_factory=NormalizationConfig, @@ -496,12 +496,6 @@ class TransformerConfig(BaseModelConfig): doc="May be used to freeze some experts by setting their scale to zero.", hint=FieldHint.feature, ) - per_layer_lr_scale: list[float] | None = Field( - default=None, - desc="Custom learning rate scale for each layer.", - doc="May be used to freeze some layers by setting their scale to zero.", - hint=FieldHint.feature, - ) router_lr_scale: float | None = Field( default=None, desc="Custom learning rate for the MoE router weight.", From 9af5ee5da24fb693b0a03a3dd722af19ed0f98ce Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 15:00:22 +0000 Subject: [PATCH 082/122] merged also prediction_loss_coefficient from #243 --- fast_llm/layers/common/config.py | 1 + fast_llm/layers/language_model/config.py | 10 ++++++++++ fast_llm/layers/ssm/config.py | 2 +- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index e8e068c0c..50ccab01b 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -11,6 +11,7 @@ from fast_llm.layers.common.normalization import LayerNorm, RMSNorm +@config_class() class LLMBlockConfig(BaseModelConfig): _abstract = False diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 96ea3f7c9..f6d376164 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -152,6 +152,12 @@ class LanguageModelBaseConfig(BaseModelConfig): doc="May be used to freeze the output weights by setting their scale to zero.", hint=FieldHint.feature, ) + prediction_loss_coefficient: list[float] | None = Field( + default=None, + desc="Loss coefficient for each prediction head.", + doc="If not provided, all heads are equally weighted.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.transformer.validate() @@ -170,6 +176,10 @@ def _validate(self) -> None: if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") + if isinstance(self.prediction_loss_coefficient, list): + Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) + for coeff in self.prediction_loss_coefficient: + Assert.geq(coeff, 0) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: self.transformer.setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 846ab43d0..6cfe2ebe8 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -121,7 +121,7 @@ class SSMConfig(LLMBlockConfig): desc="Inner dimension for Mamba2 blocks.", hint=FieldHint.core, ) - mamba_lr_scale: float = Field( + mamba_lr_scale: float | None = Field( default=None, desc="Learning rate scale for Mamba blocks.", hint=FieldHint.feature, From 1a7939bf52ee06c7c2b296364343d223ac017b24 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 15:23:40 +0000 Subject: [PATCH 083/122] added logging in mamba --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 16686fe8e..5526516f2 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,3 +1,4 @@ +import logging import math import causal_conv1d @@ -11,6 +12,8 @@ from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ from fast_llm.utils import get_lr_scale +logger = logging.getLogger(__name__) + """ This code is adapted fropm https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py """ @@ -47,6 +50,7 @@ def __init__( self._return_input = return_input layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) + logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) From 532d0d577412d1601f5cfc187f1c1e853471ab43 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 17:02:15 +0000 Subject: [PATCH 084/122] no norm layer freezing --- fast_llm/layers/transformer/transformer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 0b0d53345..fd56ba087 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -38,9 +38,10 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None - self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale=layer_lr_scale) - self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale=layer_lr_scale) + # layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + # we dont want to freeze norm layers here + self.norm_1 = self._config.normalization.get_layer(hidden_dim) + self.norm_2 = self._config.normalization.get_layer(hidden_dim) self._create_mixer() From 834913060a7e7eab7c4d0cfd7dba7e3bc6fa1228 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 20:07:52 +0000 Subject: [PATCH 085/122] test --- fast_llm/layers/ssm/discrete_mamba2.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 5526516f2..c518e798b 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -74,13 +74,15 @@ def __init__( # TODO: double check innitializations # Projections - self.in_proj = Linear(td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size)) + self.in_proj = Linear( + td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size) + ) # , lr_scale=mamba_layer_lr_scale) self.z_bias = ( ParameterMeta.from_dims( (td_inner,), weight_decay=False, init_method=init_zeros_, - lr_scale=mamba_layer_lr_scale, + # lr_scale=mamba_layer_lr_scale, ) if not bias else 0.0 @@ -92,10 +94,12 @@ def __init__( init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - lr_scale=mamba_layer_lr_scale, + # lr_scale=mamba_layer_lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale + (td_conv,), + init_method=bias_init_method(self.conv1d_weight), + # , lr_scale=mamba_layer_lr_scale ) # D "skip" parameter @@ -103,7 +107,7 @@ def __init__( (td_n_qk_heads,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + # lr_scale=mamba_layer_lr_scale, ) # out_proj @@ -112,7 +116,7 @@ def __init__( td_model, bias=bias, weight_init_method=kaiming_init_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + # lr_scale=mamba_layer_lr_scale, ) @property From 023102c4bcab18b05408a791318c780f84d6462b Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 20:19:57 +0000 Subject: [PATCH 086/122] test --- fast_llm/tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 611eb9f48..c82a3bf18 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -235,8 +235,8 @@ def __init__( self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) # TODO: re-enable when fixed? - # self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) - self.requires_grad = requires_grad + self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + # self.requires_grad = requires_grad # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) From 865da957496f50e8f703e5a52b2be1d5cd9cbeb3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 20:43:38 +0000 Subject: [PATCH 087/122] debug --- fast_llm/layers/ssm/discrete_mamba2.py | 20 +++++++++++--------- fast_llm/tensor.py | 2 -- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index c518e798b..d4f9f84d3 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -75,14 +75,18 @@ def __init__( # TODO: double check innitializations # Projections self.in_proj = Linear( - td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size) - ) # , lr_scale=mamba_layer_lr_scale) + td_model, + td_inner_proj, + bias=bias, + weight_init_method=kaiming_init_(td_model.size), + lr_scale=mamba_layer_lr_scale, + ) self.z_bias = ( ParameterMeta.from_dims( (td_inner,), weight_decay=False, init_method=init_zeros_, - # lr_scale=mamba_layer_lr_scale, + lr_scale=mamba_layer_lr_scale, ) if not bias else 0.0 @@ -94,12 +98,10 @@ def __init__( init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - # lr_scale=mamba_layer_lr_scale, + lr_scale=mamba_layer_lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (td_conv,), - init_method=bias_init_method(self.conv1d_weight), - # , lr_scale=mamba_layer_lr_scale + (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale ) # D "skip" parameter @@ -107,7 +109,7 @@ def __init__( (td_n_qk_heads,), weight_decay=False, init_method=init_ones_, - # lr_scale=mamba_layer_lr_scale, + lr_scale=mamba_layer_lr_scale, ) # out_proj @@ -116,7 +118,7 @@ def __init__( td_model, bias=bias, weight_init_method=kaiming_init_(td_inner.size), - # lr_scale=mamba_layer_lr_scale, + lr_scale=mamba_layer_lr_scale, ) @property diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index c82a3bf18..849307563 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -234,9 +234,7 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - # TODO: re-enable when fixed? self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) - # self.requires_grad = requires_grad # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) From 87c93d32583d1117b1410c1d25dd1b97de656f38 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 21:06:50 +0000 Subject: [PATCH 088/122] comment --- fast_llm/layers/transformer/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index fd56ba087..b51ba1e94 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -38,8 +38,8 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - # layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None - # we dont want to freeze norm layers here + # Note, layer_lr_scale does not impact the norms + # TODO: add a seperate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) From a18b80f575c0be1aacf827eb49bce9d7d295ca9e Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 22:04:18 +0000 Subject: [PATCH 089/122] debug --- fast_llm/tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 849307563..611eb9f48 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -234,7 +234,9 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + # TODO: re-enable when fixed? + # self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + self.requires_grad = requires_grad # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) From 40d5437917869a855f7b31ef96296ff5543b6518 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 14 May 2025 13:50:55 +0000 Subject: [PATCH 090/122] wip --- .../modeling_ssm_hybrid_apriel.py | 215 ++++++++++++++++-- .../apriel_ssm/modeling_ssm_apriel.py | 2 + tests/common.py | 8 +- 3 files changed, 204 insertions(+), 21 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index 6800158ee..5d8f4cc51 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -29,6 +29,175 @@ logger = logging.get_logger(__name__) +class HybridMambaAttentionStaticCache(Cache): + def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): + super().__init__() # config, batch_size, max_length, device, dtype) + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + self.batch_size = batch_size + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = config.max_position_embeddings if max_length is None else max_length + + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) + + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + new_layer_conv_state = torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + + new_layer_ssm_state = torch.zeros( + batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype + ) + new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) + new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) + else: + # Attention or MLP layer + new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + self.transformer_layers.append(i) + + # if not is_torchdynamo_compiling(): + # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # new_layer_key_cache = getattr(self, f"key_cache_{i}") + # new_layer_value_cache = getattr(self, f"value_cache_{i}") + # torch._dynamo.mark_static_address(new_layer_key_cache) + # torch._dynamo.mark_static_address(new_layer_value_cache) + # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) + # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) + # torch._dynamo.mark_static_address(new_layer_conv_state) + # torch._dynamo.mark_static_address(new_layer_ssm_state) + # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") + # new_layer_conv_state = getattr(self, f"conv_states_{i}") + + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + self.conv_states.append(new_layer_conv_state) + self.ssm_states.append(new_layer_ssm_state) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + cache_position = cache_kwargs.get("cache_position") + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = None) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx is None: + layer_idx = self.transformer_layers[0] + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + # Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py class HybridMambaAttentionDynamicCache(DynamicCache): """ @@ -111,14 +280,6 @@ def reorder_cache(self, beam_idx: torch.LongTensor): device = self.ssm_states[layer_idx].device self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") @@ -549,7 +710,7 @@ def forward( u, return_mixer_matrix=False, past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - cache_position: Optional[torch.LongTensor] = None, + inference_params=None, **kwargs, ): """ @@ -563,10 +724,12 @@ def forward( ssm_state, conv_state = None, None if past_key_value is not None: ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) - if cache_position[0] > 0: + if inference_params is not None and inference_params.seqlen_offset > 0: # States are updated inplace + # TODO: make sure inference_params with seqlen_offset are properly initialized u = u.squeeze(1) if len(u.shape) == 3 else u - out, _ = self.step(u, ssm_state, conv_state) + out, _, _ = self.step(u, ssm_state, conv_state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out return {"hidden_states": out} # Hacky way to initialize state during inference @@ -660,8 +823,8 @@ def step(self, u, ssm_state, conv_state, **kwargs): dim=-1, ) - xBC, conv_state = self.convolutional_step(xBC, conv_state) - conv_state.copy_(conv_state) # update state in place + xBC, conv_state_new = self.convolutional_step(xBC, conv_state) + conv_state.copy_(conv_state_new) # update state in place x, B, C = torch.split( xBC, @@ -698,7 +861,7 @@ def step(self, u, ssm_state, conv_state, **kwargs): # Norm and gate out = self.out_proj(y * F.silu(z + self.z_bias)) - return out, ssm_state + return out, ssm_state, conv_state # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): # device = self.in_proj.weight.device @@ -908,6 +1071,13 @@ class AprielSSMPreTrainedModel(PreTrainedModel): config_class = AprielSSMHybridConfig base_model_prefix = "model" _no_split_modules = ["AprielDecoderLayer", "AprielSSMDecoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -1018,6 +1188,7 @@ def __init__(self, config: AprielSSMHybridConfig, device=None, dtype=None, **kwa self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) self.gradient_checkpointing = False self.rotary_emb = AprielRotaryEmbedding(config=config) + self.has_transformer_layers = any(type == "t" for type in config.hybrid_block_layout) # Initialize weights and apply final processing self.post_init() @@ -1078,23 +1249,25 @@ def forward( "provided, so no cache will be returned." ) - if cache_position is None: + if cache_position is None and self.has_transformer_layers: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: + if position_ids is None and self.has_transformer_layers: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = ( + self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions) + if self.has_transformer_layers + else None ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) if self.has_transformer_layers else None # decoder layers all_hidden_states = () if output_hidden_states else None @@ -1152,7 +1325,9 @@ def _update_causal_mask( # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_static_cache = isinstance(past_key_values, StaticCache) or isinstance( + past_key_values, HybridMambaAttentionStaticCache + ) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: diff --git a/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py index a46530fcd..09dc8259c 100644 --- a/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py +++ b/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py @@ -225,7 +225,9 @@ def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs) state = self._get_states_from_cache(inference_params, batch) if inference_params.seqlen_offset > 0: # States are updated inplace + u = u.squeeze(1) if len(u.shape) == 3 else u out, _ = self.step(u, state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out return {"hidden_states": out} # Hacky way to initialize state during inference diff --git a/tests/common.py b/tests/common.py index 569d690cc..2f2fd5f79 100644 --- a/tests/common.py +++ b/tests/common.py @@ -63,7 +63,8 @@ f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", - f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", + # f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", + f"model.multi_stage.debug_all_param_gradients=0", "model.multi_stage.debug_tensor_parallel=True", "model.distributed.reproducible_init=True", "model.distributed.timeout=10", @@ -201,6 +202,11 @@ CONFIG_LLAMA_MTP_MEGATRON = None CONFIG_LLAMA_MTP_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ "model.base_model.prediction_heads=4", + "model.base_model.embeddings_lr_scale=0", + "model.base_model.transformer.per_layer_lr_scale=[0.1,0.0000001,0.0000001,1,1,.1]", + # "model.base_model.output_lr_scale=0", + # "model.base_model.prediction_loss_coefficient=[1, .5, .5, 0]", + # "model.base_model.cross_entropy_splits=4", ] CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] From 72ace3b08cbd9255d8cdf1a64030625dc9e28fe4 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 14 May 2025 19:01:51 +0000 Subject: [PATCH 091/122] fix --- fast_llm/engine/checkpoint/safe_load.py | 4 ++-- fast_llm/engine/multi_stage/config.py | 5 +++++ fast_llm/engine/multi_stage/fsdp.py | 11 ++++++++--- fast_llm/engine/multi_stage/multi_stage.py | 7 +------ fast_llm/engine/multi_stage/stage.py | 2 +- fast_llm/engine/multi_stage/stage_base.py | 4 ++-- tests/test_checkpoint.py | 3 +-- 7 files changed, 20 insertions(+), 16 deletions(-) diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index 2eec57e02..4b7366e4b 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -40,8 +40,8 @@ def __enter__(self) -> "SafeLoad": triton_fill(self_shard, math.nan) # Reset and count shard pads for _, fsdp, fsdp_shards in self._model.split_shards_by_fsdp(self._self_shards): - for fsdp_shard in fsdp_shards.values(): - self._loaded += fsdp.reset_shard_pad(fsdp_shard) + for shard_name, fsdp_shard in fsdp_shards.items(): + self._loaded += fsdp.reset_shard_pad(fsdp_shard, shard_name) return self def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index e2d04f80f..40002ce61 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -36,6 +36,11 @@ logger = logging.getLogger(__name__) +class ShardName: + weights = "weights" + grads = "grads" + + class StageMode(str, enum.Enum): # Allow forward and backward passes and optimizer. # TODO: Add mode for forward and backward but not optimizer? diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index e9c84aa30..61d1c7a8e 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedDim from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, StageMode +from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, ShardName, StageMode from fast_llm.functional.triton.pointwise import triton_add, triton_copy from fast_llm.logging import log_distributed_tensor from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta @@ -246,13 +246,14 @@ def setup( ) self._parameter_buffers[parameter_name] = parameter_buffer - def reset_shard_pad(self, shard: torch.Tensor) -> int: + def reset_shard_pad(self, shard: torch.Tensor, shard_name: str) -> int: assert self._is_setup assert self._mode.on_device # TODO: Needed? # Prevent nans with the padded values # Also ensures a correct parameter count in loading context. - self._weight_shard_meta.validate(shard) + shard_meta = self._weight_shard_meta if shard_name == ShardName.weights else self._grad_shard_meta + shard_meta.validate(shard) if self._shard_pad > 0: shard[-self._shard_pad :].zero_() return self._shard_pad @@ -452,5 +453,9 @@ def copy_shard_overlaps( begin, end = self._parameter_range_in_shard(name) for shard_name, shard in shards.items(): + # Shards can be empty (frozen weights) + if shard.numel() == 0: + Assert.eq(loaded_shards[shard_name].numel(), 0) + continue shard[begin:end][overlap_mask] = loaded_shards[shard_name][overlap_index_map_masked] counter += overlap_count diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 21d0fe557..497d11108 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -14,7 +14,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.engine.multi_stage.fsdp import FSDP from fast_llm.engine.multi_stage.stage import Stage from fast_llm.engine.optimizer.config import ParamGroup @@ -24,11 +24,6 @@ logger = logging.getLogger(__name__) -class ShardName: - weights = "weights" - grads = "grads" - - class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 675e878b3..179a94c10 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -156,7 +156,7 @@ def reduce_gradients(self, accumulate=False) -> None: level=self._config.debug_param_gradients, global_=False, ) - if self._config.debug_all_param_gradients: + if self._config.debug_all_param_gradients and fsdp.requires_grad: fsdp.log_shard( name="gradient", shard=fsdp.grad_shard, diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index fd50f55c5..e1b444716 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import StageConfig, StageMode +from fast_llm.engine.multi_stage.config import ShardName, StageConfig, StageMode from fast_llm.engine.multi_stage.fsdp import FSDP from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.logging import log_generator @@ -209,7 +209,7 @@ def initialize_weights(self) -> None: meta.init_parameter(parameter, self._distributed) if self.mode.on_device: - fsdp.reset_shard_pad(fsdp.weight_shard) + fsdp.reset_shard_pad(fsdp.weight_shard, ShardName.weights) if self._config.debug_param_init: log_generator("CPU generator after reset", torch.random.default_generator) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 257947e96..042f2bb23 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -14,8 +14,7 @@ FastLLMCheckpointFormat, ModelConfigType, ) -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode -from fast_llm.engine.multi_stage.multi_stage import ShardName +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConversionConfig from tests.common import ( From 121e9064c0970ff15595f80077d29f88df36f6c4 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 14 May 2025 19:43:05 +0000 Subject: [PATCH 092/122] test + comment --- fast_llm/tensor.py | 5 ++--- tests/common.py | 5 +---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 611eb9f48..ad2d42d1a 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -234,9 +234,8 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - # TODO: re-enable when fixed? - # self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) - self.requires_grad = requires_grad + # TODO: note, this pevents the tes_checkpoints to pass for MODEL=llama-mtp, they pass with `self.requires_grad=requires_grad` instead. However, the model export seem to work as expected, at least for hybrid SSM. + self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) diff --git a/tests/common.py b/tests/common.py index 2f2fd5f79..16d5114bf 100644 --- a/tests/common.py +++ b/tests/common.py @@ -203,10 +203,7 @@ CONFIG_LLAMA_MTP_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ "model.base_model.prediction_heads=4", "model.base_model.embeddings_lr_scale=0", - "model.base_model.transformer.per_layer_lr_scale=[0.1,0.0000001,0.0000001,1,1,.1]", - # "model.base_model.output_lr_scale=0", - # "model.base_model.prediction_loss_coefficient=[1, .5, .5, 0]", - # "model.base_model.cross_entropy_splits=4", + "model.base_model.transformer.per_layer_lr_scale=[0.1,0,0,1,1,.1]", ] CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] From 8a8fa77be71a12a770ae43ee0d369ddec5e189eb Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 16 May 2025 15:42:04 +0000 Subject: [PATCH 093/122] fix --- fast_llm/engine/multi_stage/fsdp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 61d1c7a8e..d24c8f842 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -455,7 +455,10 @@ def copy_shard_overlaps( for shard_name, shard in shards.items(): # Shards can be empty (frozen weights) if shard.numel() == 0: - Assert.eq(loaded_shards[shard_name].numel(), 0) + continue + if loaded_shards[shard_name].numel() == 0: + shard[begin:end][overlap_mask] = 0 + counter += overlap_count continue shard[begin:end][overlap_mask] = loaded_shards[shard_name][overlap_index_map_masked] counter += overlap_count From 8e259904a48725a7b2b3e30bd82139fcaeb89fcd Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 16 May 2025 15:45:23 +0000 Subject: [PATCH 094/122] add test with frozen weights --- tests/test_checkpoint.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 042f2bb23..9d5c86e95 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -89,6 +89,23 @@ def test_resume(): ) +@pytest.mark.depends(on=["test_checkpoint_and_eval"]) +def test_resume_frozen(): + run_test_script( + f"test_{TEST_MODEL}_resume_frozen", + CONFIG_COMMON + + [ + "training.checkpoint.interval=1", + "training.evaluations.validation.interval=2", + "training.evaluations.validation.iterations=1", + "model.base_model.transformer.mlp_lr_scale=0.", + ], + compare=f"test_{TEST_MODEL}_checkpoint_and_eval", + prepare_fn=_prepare_resume_fn, + compare_fn=_compare_resume_fn, + ) + + def _run_conversion(config: ConversionConfig): if config.output.path.is_dir() and not REUSE_RESULTS: shutil.rmtree(config.output.path) From 456a0c528e7ed5d6648727c5f2f8fb7c7e18c318 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 16 May 2025 20:15:57 +0000 Subject: [PATCH 095/122] add description for tests --- tests/common.py | 3 ++- tests/test_checkpoint.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/common.py b/tests/common.py index 569d690cc..6179957b3 100644 --- a/tests/common.py +++ b/tests/common.py @@ -361,6 +361,7 @@ def run_test_script( config: CompareConfig | None = None, prepare_fn=None, compare_fn=None, + do_compare: bool = True, ): if torch.cuda.device_count() < num_gpus: pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})") @@ -413,7 +414,7 @@ def run_test_script( completed_proc = subprocess.run(command, env=env, timeout=60) if completed_proc.returncode: raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") - if compare: + if compare and do_compare: if compare_fn is not None: compare_fn(TEST_RESULTS_PATH / name, TEST_RESULTS_PATH / compare) compare_tensor_logs( diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 9d5c86e95..4dfd23a8b 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -75,6 +75,7 @@ def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): @pytest.mark.depends(on=["test_checkpoint_and_eval"]) def test_resume(): + # Resume from iteration=1 and compare outputs with the baseline run. run_test_script( f"test_{TEST_MODEL}_resume", CONFIG_COMMON @@ -91,6 +92,7 @@ def test_resume(): @pytest.mark.depends(on=["test_checkpoint_and_eval"]) def test_resume_frozen(): + # Resume with frozen mlp. No comparison. run_test_script( f"test_{TEST_MODEL}_resume_frozen", CONFIG_COMMON @@ -102,7 +104,7 @@ def test_resume_frozen(): ], compare=f"test_{TEST_MODEL}_checkpoint_and_eval", prepare_fn=_prepare_resume_fn, - compare_fn=_compare_resume_fn, + do_compare=False, ) From 87efd455af8fa4bf6f8e4a42cdfa73bf89913498 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 20 May 2025 12:17:32 +0000 Subject: [PATCH 096/122] 15b model apriel hybrid --- .../configuration_ssm_hybrid_apriel15b.py | 21 + .../modeling_ssm_hybrid_apriel15b.py | 948 ++++++++++++++++++ .../modeling_ssm_hybrid_apriel.py | 2 +- .../ssm/external/eval/apriel_eval_wrapper.py | 88 +- .../ssm/external/make_hybrid_checkpoint.py | 41 + 5 files changed, 1097 insertions(+), 3 deletions(-) create mode 100644 fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py create mode 100644 fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py create mode 100644 fast_llm/models/ssm/external/make_hybrid_checkpoint.py diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py new file mode 100644 index 000000000..bc2e603cf --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py @@ -0,0 +1,21 @@ +from transformers import MistralConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class AprielSSMHybridConfig(MistralConfig): + def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + super().__init__(**kwargs) + self.hybrid_block_layout = hybrid_block_layout + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 24, + "n_qk_heads": 24, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 24 * self.head_dim, # num_heads * head_dim + } diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py new file mode 100644 index 000000000..ba798a1ca --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -0,0 +1,948 @@ +import copy +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from torch import nn +from transformers import GenerationMixin, MistralModel +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.mistral.modeling_mistral import ( + MistralDecoderLayer, + MistralMLP, + MistralPreTrainedModel, + MistralRMSNorm, +) +from transformers.processing_utils import Unpack +from transformers.utils import LossKwargs, logging +from transformers.utils.generic import ModelOutput + +from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig + +logger = logging.get_logger(__name__) + + +class HybridMambaAttentionStaticCache(Cache): + def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): + super().__init__() # config, batch_size, max_length, device, dtype) + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + self.batch_size = batch_size + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = config.max_position_embeddings if max_length is None else max_length + + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) + + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + new_layer_conv_state = torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + + new_layer_ssm_state = torch.zeros( + batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype + ) + new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) + new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) + else: + # Attention or MLP layer + new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + self.transformer_layers.append(i) + + # if not is_torchdynamo_compiling(): + # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # new_layer_key_cache = getattr(self, f"key_cache_{i}") + # new_layer_value_cache = getattr(self, f"value_cache_{i}") + # torch._dynamo.mark_static_address(new_layer_key_cache) + # torch._dynamo.mark_static_address(new_layer_value_cache) + # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) + # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) + # torch._dynamo.mark_static_address(new_layer_conv_state) + # torch._dynamo.mark_static_address(new_layer_ssm_state) + # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") + # new_layer_conv_state = getattr(self, f"conv_states_{i}") + + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + self.conv_states.append(new_layer_conv_state) + self.ssm_states.append(new_layer_ssm_state) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + cache_position = cache_kwargs.get("cache_position") + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = None) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx is None: + layer_idx = self.transformer_layers[0] + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + self.conv_states += [ + torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + ] + self.ssm_states += [ + torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +@dataclass +class AprielHybridCausalOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + attention_weights: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +# This is from LLmaba/Mohawk: https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward( + self, + u, + return_mixer_matrix=False, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + inference_params=None, + **kwargs, + ): + """ + u: (B, L, D) + Returns: same shape as u + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + ssm_state, conv_state = None, None + if past_key_value is not None: + ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + if inference_params is not None and inference_params.seqlen_offset > 0: + # States are updated inplace + # TODO: make sure inference_params with seqlen_offset are properly initialized + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _, _ = self.step(u, ssm_state, conv_state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if ssm_state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if ssm_state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(ssm_state is not None), + ) + + if ssm_state is not None: + y, ssm_state_update = result + ssm_state.copy_(ssm_state_update) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + def step(self, u, ssm_state, conv_state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state_new = self.convolutional_step(xBC, conv_state) + conv_state.copy_(conv_state_new) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + ssm_state = ssm_state.to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=ssm_state, # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, ssm_state, conv_state + + # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # device = self.in_proj.weight.device + # # conv_state: + # conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + # conv_state = torch.zeros( + # batch_size, + # self.d_conv, + # self.conv1d.weight.shape[0], + # device=device, + # dtype=conv_dtype, + # ).transpose(1, 2) + # # ssm_state: + # ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + # ssm_state = torch.zeros( + # batch_size, + # self.n_v_heads, + # self.headdim, + # self.d_state, + # device=device, + # dtype=ssm_dtype, + # ) + # return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + # if self.layer_idx not in inference_params.ssm_states: + # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + # batch_size, inference_params.max_seqlen, dtype=torch.float32 + # ) + # Get states + ssm_states = inference_params.ssm_states[self.layer_idx] + conv_states = inference_params.conv_states[self.layer_idx] + if initialize_states: + ssm_states.zero_() + conv_states.zero_() + return ssm_states, conv_states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielSSMDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + **kwargs, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + # outputs["hidden_states"] = hidden_states + outputs = (hidden_states,) + + return outputs + + +class AprielHybridIdentity(nn.Module): + def __init__(self, config: AprielSSMHybridConfig): + super().__init__() + self.config = config + + def forward(self, hidden_states: torch.Tensor, **kwargs): + return (hidden_states,) + + +class AprielSSMHybridModel(MistralModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] + Args: + config: AprielSSMHybridConfig + """ + + def __init__(self, config: AprielSSMHybridConfig, **kwargs): + config_copy = copy.deepcopy(config) + config_copy.num_hidden_layers = 0 + super().__init__(config_copy, **kwargs) + blocks = [] + logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") + for layer_idx, type in enumerate(config.hybrid_block_layout): + if type == "m2d": + blocks.append(AprielSSMDecoderLayer(config, layer_idx)) + elif type == "t": + blocks.append(MistralDecoderLayer(config, layer_idx)) + elif type == "i": + blocks.append(AprielHybridIdentity(config)) + else: + raise ValueError(f"Invalid block type: {type}") + self.layers = nn.ModuleList(blocks) + + # Initialize weights and apply final processing + self.post_init() + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielHybridPreTrainedModel(MistralPreTrainedModel): + config_class = AprielSSMHybridConfig + base_model_prefix = "model" + _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + +class AprielSSMHybridForCausalLM(AprielHybridPreTrainedModel, GenerationMixin): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.model = AprielSSMHybridModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + # "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return AprielHybridCausalOutput( + loss=loss, + logits=logits, + all_hidden_states=outputs.hidden_states, + past_key_values=outputs.past_key_values, + ) + + +__all__ = [ + "AprielSSMHybridForCausalLM", + "AprielSSMHybridModel", + "AprielSSMPreTrainedModel", +] diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index 5d8f4cc51..ddb7d0f77 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -1482,7 +1482,7 @@ def prepare_inputs_for_generation( ): # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - empty_past_kv = past_key_values is None + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index 02c9176b9..e15de8bb8 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -92,5 +92,89 @@ def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype] ) def _model_generate(self, context, max_length, stop, **generation_kwargs): - # FOR now evaluating with non-generation tasks - raise NotImplementedError("Generation not implemented yet for AprielHybridSSMWrapper") + + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self.tokenizer, + stop, + context.shape[1], + context.shape[0], + ) + + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + use_cache=True, + **generation_kwargs, + ) + + +@register_model("apriel_hybrid_ssm_15b") +class AprielHybridSSMWrapper(HFLM): + """Wrapper for AprielHybridSSM model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import ( + AprielSSMHybridConfig, + ) + + self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielSSMHybridForCausalLM, + ) + + self._model = AprielSSMHybridForCausalLM.from_pretrained( + pretrained, + device=self._device, + torch_dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + **kwargs, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self.tokenizer, + stop, + context.shape[1], + context.shape[0], + ) + + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + use_cache=True, + **generation_kwargs, + ) diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint.py new file mode 100644 index 000000000..a0616ab64 --- /dev/null +++ b/fast_llm/models/ssm/external/make_hybrid_checkpoint.py @@ -0,0 +1,41 @@ +import gc + +import click +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielSSMHybridForCausalLM + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +@click.command() +@click.option("--identity_index", type=int, required=True) +@click.option("--save_dir", type=str, required=True) +def main(identity_index: int, save_dir: str): + checkpoint = "ServiceNow-AI/Apriel-Nemotron-15b-Thinker" + config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True) + + hybrid_block_layout = ["t"] * config.num_hidden_layers + if identity_index >= 0: + hybrid_block_layout[identity_index] = "i" + + hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(), hybrid_block_layout=hybrid_block_layout) + hybrid_apriel_model = AprielSSMHybridForCausalLM(hybrdif_apriel_config) + hybrid_apriel_model.to(dtype=torch.bfloat16).to(device) + + apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True) + apriel_state_dict = apriel_model.state_dict() + hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False) + + hybrid_apriel_model.save_pretrained(save_dir, save_config=True) + torch.cuda.empty_cache() + del hybrid_apriel_model + del apriel_model + del apriel_state_dict + gc.collect() + + +if __name__ == "__main__": + main() From aafbfb569c1370a683ec4931177edc30a1012e4d Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 20 May 2025 12:22:01 +0000 Subject: [PATCH 097/122] nvm --- fast_llm/tensor.py | 1 - tests/common.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index ad2d42d1a..849307563 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -234,7 +234,6 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - # TODO: note, this pevents the tes_checkpoints to pass for MODEL=llama-mtp, they pass with `self.requires_grad=requires_grad` instead. However, the model export seem to work as expected, at least for hybrid SSM. self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) diff --git a/tests/common.py b/tests/common.py index f9cb324ef..fe3120c22 100644 --- a/tests/common.py +++ b/tests/common.py @@ -202,8 +202,6 @@ CONFIG_LLAMA_MTP_MEGATRON = None CONFIG_LLAMA_MTP_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ "model.base_model.prediction_heads=4", - "model.base_model.embeddings_lr_scale=0", - "model.base_model.transformer.per_layer_lr_scale=[0.1,0,0,1,1,.1]", ] CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] From c7fe8d74af47abc5d3f264349eca4eb4303e8ce9 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 20 May 2025 19:34:10 +0000 Subject: [PATCH 098/122] nvm --- .../modeling_ssm_hybrid_apriel15b.py | 48 +++++++++++++++---- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index ba798a1ca..4b2e5724e 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -10,16 +10,12 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from torch import nn -from transformers import GenerationMixin, MistralModel +from transformers import GenerationMixin from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.models.mistral.modeling_mistral import ( - MistralDecoderLayer, - MistralMLP, - MistralPreTrainedModel, - MistralRMSNorm, -) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm from transformers.processing_utils import Unpack from transformers.utils import LossKwargs, logging from transformers.utils.generic import ModelOutput @@ -759,6 +755,7 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): config_copy = copy.deepcopy(config) config_copy.num_hidden_layers = 0 super().__init__(config_copy, **kwargs) + self.config = config blocks = [] logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") for layer_idx, type in enumerate(config.hybrid_block_layout): @@ -779,10 +776,11 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -class AprielHybridPreTrainedModel(MistralPreTrainedModel): +class AprielHybridPreTrainedModel(PreTrainedModel): config_class = AprielSSMHybridConfig base_model_prefix = "model" _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True @@ -791,8 +789,24 @@ class AprielHybridPreTrainedModel(MistralPreTrainedModel): _supports_static_cache = True _supports_attention_backend = True + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MistralRMSNorm): + module.weight.data.fill_(1.0) + class AprielSSMHybridForCausalLM(AprielHybridPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + def __init__(self, config, **kwargs): super().__init__(config, **kwargs) self.model = AprielSSMHybridModel(config) @@ -802,6 +816,24 @@ def __init__(self, config, **kwargs): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + def prepare_inputs_for_generation( self, input_ids, From c285e8de0e831dbe4f8e211cd89f9eed70b03aa8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 20 May 2025 19:51:30 +0000 Subject: [PATCH 099/122] nvm --- tests/common.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/common.py b/tests/common.py index fe3120c22..6179957b3 100644 --- a/tests/common.py +++ b/tests/common.py @@ -63,8 +63,7 @@ f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", - # f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", - f"model.multi_stage.debug_all_param_gradients=0", + f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", "model.multi_stage.debug_tensor_parallel=True", "model.distributed.reproducible_init=True", "model.distributed.timeout=10", From 6765d28ec6c77cab4f8bf93b7e2364ff258c7a92 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 21 May 2025 12:46:56 +0000 Subject: [PATCH 100/122] hybrid thinker --- fast_llm/models/ssm/config.py | 12 ++++ fast_llm/models/ssm/conversion.py | 58 +++++++++++++++++++ .../configuration_ssm_hybrid_apriel15b.py | 2 + 3 files changed, 72 insertions(+) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 15a75ac2e..31f7d8792 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -158,6 +158,17 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHHybridHuggingfaceCheckpointHandler +class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import AprielThinkerSSMHHybridHuggingfaceCheckpointHandler + + return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler + + @config_class() class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False @@ -167,6 +178,7 @@ class HybridSSMModelConfig(FastLLMModelConfig): LLambaHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, AprielSSMHHybridHuggingfaceCheckpointFormat, + AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 357a26c04..e683ae311 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -7,6 +7,7 @@ from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, ConstantImportParamConverter, + IgnoreImportParamConverter, IgnoreImportWeightConverter, MappedConfigParamConverter, ParamConverter, @@ -23,6 +24,7 @@ from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, + AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat, ) @@ -582,3 +584,59 @@ def _load_config(cls, directory: pathlib.Path | str) -> dict: def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: with open(directory / "config.json", "w") as f: json.dump(config, f) + + +class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( + HybridModelCheckpointHandler, # handles the block structure parameter + CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers + CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers +): + """ + Lamba-like configs, models that interleave LLama like layers with LLamba-like SSM layers. + """ + + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = AprielThinkerSSMHHybridHuggingfaceCheckpointFormat + _default_block_type: str = SSMBlockType.mamba2_discrete.value + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("ssm", "d_inner"),), + export_names=(("ssm_cfg", "d_inner"),), + ), + IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), + ] + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), + ] + + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + if not os.path.exists(directory / "config.json"): + raise FileNotFoundError(f"config.json not found in {directory}") + with open(directory / "config.json") as f: + config = json.load(f) + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + with open(directory / "config.json", "w") as f: + json.dump(config, f) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py index bc2e603cf..e15df990f 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py @@ -5,6 +5,8 @@ class AprielSSMHybridConfig(MistralConfig): + model_type = "apriel_ssm_thinker_hybrid" + def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): super().__init__(**kwargs) self.hybrid_block_layout = hybrid_block_layout From b0fe37b72395f338a145c56d44419656ea0f16ba Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 22 May 2025 13:05:12 +0000 Subject: [PATCH 101/122] nvm --- fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index e15de8bb8..37c0311b6 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -118,7 +118,7 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs): @register_model("apriel_hybrid_ssm_15b") -class AprielHybridSSMWrapper(HFLM): +class AprielHybrid15bSSMWrapper(HFLM): """Wrapper for AprielHybridSSM model for compatibility with lm-evaluation-harness.""" def __init__(self, pretrained, **kwargs) -> None: From 8f84a498def3378602a339c10f3b4f45d8820cdc Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 22 May 2025 16:57:01 +0000 Subject: [PATCH 102/122] modeling --- .../modeling_ssm_hybrid_apriel15b.py | 117 +++++++++++++++++- 1 file changed, 115 insertions(+), 2 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index 4b2e5724e..777fd3cfa 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -1,5 +1,6 @@ import copy from dataclasses import dataclass +from functools import partial from typing import Any, Optional, Union import torch @@ -15,9 +16,15 @@ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel -from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm +from transformers.models.mistral.modeling_mistral import ( + MISTRAL_INPUTS_DOCSTRING, + MistralDecoderLayer, + MistralMLP, + MistralModel, + MistralRMSNorm, +) from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, logging +from transformers.utils import LossKwargs, add_start_docstrings_to_model_forward, can_return_tuple, logging from transformers.utils.generic import ModelOutput from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig @@ -772,6 +779,112 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): # Initialize weights and apply final processing self.post_init() + @can_return_tuple + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # OO: Cache is initialized in the `prepare_inputs_for_generation` method, so this can be removed + # if use_cache and past_key_values is None: + # past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... From 361aad067b7f92812c21751e09217607bb3012d7 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 22 May 2025 18:51:15 +0000 Subject: [PATCH 103/122] wip --- fast_llm/models/ssm/config.py | 14 +++++++------- fast_llm/models/ssm/conversion.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 31f7d8792..d59de1c2e 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -51,13 +51,13 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: Some of these can be setup directly in the layer config, but keeping them here for clarity. """ super().setup_tensor_space(tensor_space) - if ( - not SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout - and not SSMBlockType.mamba.value in self.hybrid_block_layout - ): - raise ValueError( - f"Block pattern must contain at least one '{SSMBlockType.mamba2_discrete.value}' or '{SSMBlockType.mamba.value}', use gpt model for transformer only architectures" - ) + # if ( + # not SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout + # and not SSMBlockType.mamba.value in self.hybrid_block_layout + # ): + # raise ValueError( + # f"Block pattern must contain at least one '{SSMBlockType.mamba2_discrete.value}' or '{SSMBlockType.mamba.value}', use gpt model for transformer only architectures" + # ) if self.ssm.dt_rank is None: mamba_dt_rank = math.ceil(self.transformer.hidden_size / 16) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index e683ae311..c5f22ccc5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -599,6 +599,25 @@ class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = AprielThinkerSSMHHybridHuggingfaceCheckpointFormat _default_block_type: str = SSMBlockType.mamba2_discrete.value + _hf_prefix: str = "model" + + def _create_weight_converters(self) -> list[WeightConverter]: + converters = super()._create_weight_converters() + # num_layers = self._model.config.base_model.transformer.num_layers + # # Embedding and output + # if self._model.config.base_model.tie_word_embeddings: + # converters.append( + # WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") + # ) + # converters.append(IgnoreImportWeightConverter((), f"{self._hf_prefix}.lm_head.weight")) + # else: + # converters.append( + # WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") + # ) + # converters.append( + # WeightConverter(f"layers.{num_layers + 1}.output_weights", f"{self._hf_prefix}.lm_head.weight") + # ) + return converters @classmethod def _create_config_converters(cls) -> list[ParamConverter]: From 65e466f54db87222bf84d9f09a40c99383c2a0d9 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 26 May 2025 16:39:43 +0000 Subject: [PATCH 104/122] nvm --- fast_llm/models/ssm/external/eval/run_lm_eval.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm/models/ssm/external/eval/run_lm_eval.py index c910bcc38..53c0febab 100644 --- a/fast_llm/models/ssm/external/eval/run_lm_eval.py +++ b/fast_llm/models/ssm/external/eval/run_lm_eval.py @@ -1,6 +1,7 @@ from lm_eval.__main__ import cli_evaluate from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import ( # noqa: F401 + AprielHybrid15bSSMWrapper, AprielHybridSSMWrapper, AprielSSMWrapper, ) From 45aa6e4c8e5bd6bcbac9bb9ff6430b1ac44c88e2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 29 May 2025 14:10:55 +0000 Subject: [PATCH 105/122] notebook --- fast_llm/models/ssm/external/15B_hybrid.ipynb | 516 ++++++++++++++++++ 1 file changed, 516 insertions(+) create mode 100644 fast_llm/models/ssm/external/15B_hybrid.ipynb diff --git a/fast_llm/models/ssm/external/15B_hybrid.ipynb b/fast_llm/models/ssm/external/15B_hybrid.ipynb new file mode 100644 index 000000000..ad4773c69 --- /dev/null +++ b/fast_llm/models/ssm/external/15B_hybrid.ipynb @@ -0,0 +1,516 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import gc\n", + "\n", + "import click\n", + "import torch\n", + "from transformers import AutoConfig, AutoModelForCausalLM\n", + "from transformers import MistralForCausalLM\n", + "\n", + "from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig\n", + "from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielSSMHybridForCausalLM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Slam 15B upcycled" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " Lead the weights of https://huggingface.co/ServiceNow-AI/Slam-15B-Upcycled/ into Thiked modeling, it shoudl work" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"/home/toolkit/dev/fml-ops/__oo_playground\")\n", + "from results_analysis.results_loader import ResultsLoader\n", + "layer_importance_path = \"/mnt/evaluations/training_evaluation/model_runs/lm_eval_runner/apriel_ssm_importance/\"\n", + "results_loader = ResultsLoader(layer_importance_path)\n", + "\n", + "results_loader.deserialize_results()\n", + "results_df = results_loader.to_df()\n", + "results_df[\"layer_index\"] = results_df.apply(lambda row: int(row[\"model_name_sanitized\"].split(\"_\")[-1] if \"layers_\" in row[\"model_name_sanitized\"] else -1), axis=1)\n", + "results_df = results_df[results_df[\"metric\"] == \"acc_norm\"]\n", + "columns_to_keep = [\"layer_index\", \"metric_value\"]\n", + "results_df = results_df[columns_to_keep]\n", + "layer_importance = results_df.groupby(\"layer_index\").mean()\n", + "layer_importance = layer_importance.sort_values(by=\"metric_value\", ascending=False).reset_index()\n", + "layer_importance = layer_importance[layer_importance[\"layer_index\"]!= -1]\n", + "layer_importance = list(layer_importance[\"layer_index\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[22,\n", + " 25,\n", + " 20,\n", + " 31,\n", + " 29,\n", + " 46,\n", + " 23,\n", + " 26,\n", + " 33,\n", + " 24,\n", + " 47,\n", + " 27,\n", + " 21,\n", + " 41,\n", + " 17,\n", + " 18,\n", + " 34,\n", + " 42,\n", + " 44,\n", + " 30,\n", + " 16,\n", + " 8,\n", + " 43,\n", + " 35,\n", + " 19,\n", + " 38,\n", + " 15,\n", + " 28,\n", + " 32,\n", + " 45,\n", + " 37,\n", + " 40,\n", + " 7,\n", + " 36,\n", + " 13,\n", + " 10,\n", + " 5,\n", + " 39,\n", + " 6,\n", + " 14,\n", + " 4,\n", + " 12,\n", + " 9,\n", + " 48,\n", + " 1,\n", + " 3,\n", + " 11,\n", + " 49,\n", + " 0]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer_importance" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "path_thinker = \"/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker\"\n", + "n_ssm = 25\n", + "\n", + "config_thinker = AutoConfig.from_pretrained(path_thinker)\n", + "hybrid_block_layout = [\"t\"] * config_thinker.num_hidden_layers\n", + "\n", + "for i in range(n_ssm):\n", + " hybrid_block_layout[layer_importance[i]] = \"m2d\"\n", + "\n", + "config_hybrid = AprielSSMHybridConfig(\n", + " **config_thinker.to_dict(),\n", + " hybrid_block_layout=hybrid_block_layout,\n", + " ssm_cfg = {\n", + " \"d_state\": 64,\n", + " \"n_v_heads\": 32,\n", + " \"n_qk_heads\": 32,\n", + " \"expand\": 1,\n", + " \"chunk_size\": 128,\n", + " \"activation\": \"identity\",\n", + " \"bias\": False,\n", + " \"d_conv\": 4,\n", + " \"d_inner\": 32 * 128\n", + " }\n", + ")\n", + "model_hybrid = AprielSSMHybridForCausalLM(config_hybrid)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['t',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 't']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrid_block_layout" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You are using a model of type llama to instantiate a model of type mistral. This is not supported for all configurations of models and can yield errors.\n", + "Loading checkpoint shards: 0%| | 0/4 [00:00 2\u001b[0m model_base \u001b[38;5;241m=\u001b[39m \u001b[43mMistralForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath_base\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m model_hybrid\u001b[38;5;241m.\u001b[39mload_state_dict(model_base\u001b[38;5;241m.\u001b[39mstate_dict(), strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:279\u001b[0m, in \u001b[0;36mrestore_default_torch_dtype.._wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 277\u001b[0m old_dtype \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mget_default_dtype()\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 279\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 281\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_default_dtype(old_dtype)\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:4342\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 4336\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_autoset_attn_implementation(\n\u001b[1;32m 4337\u001b[0m config, use_flash_attention_2\u001b[38;5;241m=\u001b[39muse_flash_attention_2, torch_dtype\u001b[38;5;241m=\u001b[39mtorch_dtype, device_map\u001b[38;5;241m=\u001b[39mdevice_map\n\u001b[1;32m 4338\u001b[0m )\n\u001b[1;32m 4340\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ContextManagers(model_init_context):\n\u001b[1;32m 4341\u001b[0m \u001b[38;5;66;03m# Let's make sure we don't run the init function of buffer modules\u001b[39;00m\n\u001b[0;32m-> 4342\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4344\u001b[0m \u001b[38;5;66;03m# Make sure to tie the weights correctly\u001b[39;00m\n\u001b[1;32m 4345\u001b[0m model\u001b[38;5;241m.\u001b[39mtie_weights()\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/models/mistral/modeling_mistral.py:729\u001b[0m, in \u001b[0;36mMistralForCausalLM.__init__\u001b[0;34m(self, config)\u001b[0m\n\u001b[1;32m 727\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, config):\n\u001b[1;32m 728\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(config)\n\u001b[0;32m--> 729\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m=\u001b[39m \u001b[43mMistralModel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 730\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvocab_size \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39mvocab_size\n\u001b[1;32m 731\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mLinear(config\u001b[38;5;241m.\u001b[39mhidden_size, config\u001b[38;5;241m.\u001b[39mvocab_size, bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/models/mistral/modeling_mistral.py:440\u001b[0m, in \u001b[0;36mMistralModel.__init__\u001b[0;34m(self, config)\u001b[0m\n\u001b[1;32m 437\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_idx \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39mpad_token_id\n\u001b[1;32m 438\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvocab_size \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39mvocab_size\n\u001b[0;32m--> 440\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membed_tokens \u001b[38;5;241m=\u001b[39m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mEmbedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvocab_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhidden_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding_idx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 441\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mModuleList(\n\u001b[1;32m 442\u001b[0m [MistralDecoderLayer(config, layer_idx) \u001b[38;5;28;01mfor\u001b[39;00m layer_idx \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(config\u001b[38;5;241m.\u001b[39mnum_hidden_layers)]\n\u001b[1;32m 443\u001b[0m )\n\u001b[1;32m 444\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm \u001b[38;5;241m=\u001b[39m MistralRMSNorm(config\u001b[38;5;241m.\u001b[39mhidden_size, eps\u001b[38;5;241m=\u001b[39mconfig\u001b[38;5;241m.\u001b[39mrms_norm_eps)\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/sparse.py:144\u001b[0m, in \u001b[0;36mEmbedding.__init__\u001b[0;34m(self, num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight, _freeze, device, dtype)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscale_grad_by_freq \u001b[38;5;241m=\u001b[39m scale_grad_by_freq\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _weight \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mweight \u001b[38;5;241m=\u001b[39m Parameter(\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mempty\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnum_embeddings\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43membedding_dim\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfactory_kwargs\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 145\u001b[0m requires_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;129;01mnot\u001b[39;00m _freeze)\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreset_parameters()\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "\u001b[0;31mRuntimeError\u001b[0m: [enforce fail at alloc_cpu.cpp:117] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 2684354560 bytes. Error code 12 (Cannot allocate memory)" + ] + } + ], + "source": [ + "path_base = path_thinker\n", + "model_base = MistralForCausalLM.from_pretrained(path_base)\n", + "\n", + "model_hybrid.load_state_dict(model_base.state_dict(), strict=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# model_hybrid.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_1ssm_leastimportant_32h_init_rand\") # 1 ssm\n", + "# model_hybrid.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_1ssm_0th_32h_init_rand\") # 1 ssm\n", + "model_hybrid.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_interleaved_ssm_starting0th_32h_init_rand\") # 1 ssm\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fast_llm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From d4a04f2e0fbbca4d64f6606990d22f683734158d Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 3 Jun 2025 14:13:36 +0000 Subject: [PATCH 106/122] hiddens tate mistral --- .../apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py index e15df990f..0506c4554 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py @@ -10,6 +10,7 @@ class AprielSSMHybridConfig(MistralConfig): def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): super().__init__(**kwargs) self.hybrid_block_layout = hybrid_block_layout + self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 self.ssm_cfg = ssm_cfg or { "d_state": 64, "n_v_heads": 24, From 978c0afb2f7a114c9a083cd7c9fe13c379d923be Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 3 Jun 2025 14:27:31 +0000 Subject: [PATCH 107/122] update transformers --- .../modeling_ssm_hybrid_apriel15b.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index 777fd3cfa..704b5beba 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -16,15 +16,9 @@ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel -from transformers.models.mistral.modeling_mistral import ( - MISTRAL_INPUTS_DOCSTRING, - MistralDecoderLayer, - MistralMLP, - MistralModel, - MistralRMSNorm, -) +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, add_start_docstrings_to_model_forward, can_return_tuple, logging +from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging from transformers.utils.generic import ModelOutput from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig @@ -780,7 +774,7 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): self.post_init() @can_return_tuple - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, From 91897dbbfc94c90c5d24051db1f0153a0fc28d95 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 5 Jun 2025 12:50:37 +0000 Subject: [PATCH 108/122] inference optim --- .../configuration_ssm_hybrid_apriel15b.py | 28 +- .../modeling_ssm_hybrid_apriel15b.py | 109 +- .../modeling_ssm_hybrid_apriel15b_cg.py | 1092 +++++++++++++++++ 3 files changed, 1213 insertions(+), 16 deletions(-) create mode 100644 fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b_cg.py diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py index 0506c4554..84242b7db 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py @@ -3,6 +3,18 @@ logger = logging.get_logger(__name__) +ssm_config_default = { + "d_state": 64, + "n_v_heads": 32, + "n_qk_heads": 32, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 32 * 128, +} + class AprielSSMHybridConfig(MistralConfig): model_type = "apriel_ssm_thinker_hybrid" @@ -11,14 +23,8 @@ def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): super().__init__(**kwargs) self.hybrid_block_layout = hybrid_block_layout self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 - self.ssm_cfg = ssm_cfg or { - "d_state": 64, - "n_v_heads": 24, - "n_qk_heads": 24, - "expand": 1, - "chunk_size": 128, - "activation": "identity", - "bias": False, - "d_conv": 4, - "d_inner": 24 * self.head_dim, # num_heads * head_dim - } + self.ssm_cfg = ssm_cfg or ssm_config_default + + for k, v in ssm_config_default.items(): + if k not in self.ssm_cfg: + self.ssm_cfg[k] = v # to make sure all elements are present in the config diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index 704b5beba..9b3a63269 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -444,6 +444,16 @@ def __init__( # out_proj self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + # In __init__, pre-allocate these tensors + self.zeros_buffer = torch.zeros((self.n_v_heads, self.headdim), device=device, dtype=dtype) + self.ones_buffer = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=device, dtype=dtype) + + self.use_cuda_graph = False + self.cuda_graph = None + self.graph_input = None + self.graph_output = None + self.graph_ssm_state = None + self.graph_conv_state = None @property def d_output(self): @@ -472,9 +482,11 @@ def forward( ssm_state, conv_state = None, None if past_key_value is not None: ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) - if inference_params is not None and inference_params.seqlen_offset > 0: + cache_position = kwargs.get("cache_position", None) + # if inference_params is not None and inference_params.seqlen_offset > 0: + if cache_position is not None and cache_position[0] > 0: # States are updated inplace - # TODO: make sure inference_params with seqlen_offset are properly initialized + # TODO: make sure using cache_position is correct here u = u.squeeze(1) if len(u.shape) == 3 else u out, _, _ = self.step(u, ssm_state, conv_state) out = out.unsqueeze(1) if len(u.shape) == 2 else out @@ -589,8 +601,8 @@ def step(self, u, ssm_state, conv_state, **kwargs): C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) ssm_state = ssm_state.to(x.dtype) - zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) - ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + zeros = self.zeros_buffer.to(A_log.device).to(x.dtype) # Just cast, don't allocate + ones = self.ones_buffer.to(A_log.device).to(x.dtype) y = selective_state_update( x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), dt=repeat(A_log, "b h -> b h p", p=self.headdim), @@ -690,6 +702,93 @@ def convolutional_step(self, xBC, conv_state): return xBC, conv_state + def enable_cuda_graph(self, batch_size=1): + """Capture CUDA graph for the step function""" + if not torch.cuda.is_available(): + return + + # Pre-allocate tensors with fixed shapes + device = next(self.parameters()).device + self.graph_input = torch.zeros(batch_size, self.d_model, device=device, dtype=torch.float16) + self.graph_ssm_state = torch.zeros( + batch_size, self.n_qk_heads, self.headdim, self.d_state, device=device, dtype=torch.float16 + ) + self.graph_conv_state = torch.zeros( + batch_size, + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_conv, + device=device, + dtype=torch.float16, + ) + self.graph_output = torch.zeros(batch_size, self.d_model, device=device, dtype=torch.float16) + + # Warmup runs + torch.cuda.synchronize() + for _ in range(3): + self._step_graph_impl(self.graph_input, self.graph_ssm_state, self.graph_conv_state) + torch.cuda.synchronize() + + # Capture graph + self.cuda_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.cuda_graph): + self.graph_output = self._step_graph_impl(self.graph_input, self.graph_ssm_state, self.graph_conv_state) + + self.use_cuda_graph = True + + def _step_graph_impl(self, u, ssm_state, conv_state): + """Graph-compatible version of step function""" + # Same logic as step() but with pre-allocated tensors + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state_new = self.convolutional_step(xBC, conv_state) + conv_state.copy_(conv_state_new) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + ssm_state = ssm_state.to(x.dtype) + zeros = self.zeros_buffer.to(A_log.device).to(x.dtype) # Just cast, don't allocate + ones = self.ones_buffer.to(A_log.device).to(x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=ssm_state, # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out + class AprielSSMDecoderLayer(nn.Module): def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): @@ -914,7 +1013,7 @@ class AprielSSMHybridForCausalLM(AprielHybridPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - def __init__(self, config, **kwargs): + def __init__(self, config: AprielSSMHybridConfig, **kwargs): super().__init__(config, **kwargs) self.model = AprielSSMHybridModel(config) self.vocab_size = config.vocab_size diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b_cg.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b_cg.py new file mode 100644 index 000000000..506ebd7b4 --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b_cg.py @@ -0,0 +1,1092 @@ +import copy +from dataclasses import dataclass +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from torch import nn +from transformers import GenerationMixin +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm +from transformers.processing_utils import Unpack +from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging +from transformers.utils.generic import ModelOutput + +from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig + +logger = logging.get_logger(__name__) + + +class HybridMambaAttentionStaticCache(Cache): + def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): + super().__init__() # config, batch_size, max_length, device, dtype) + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + self.batch_size = batch_size + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = config.max_position_embeddings if max_length is None else max_length + + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) + + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + new_layer_conv_state = torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + + new_layer_ssm_state = torch.zeros( + batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype + ) + new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) + new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) + else: + # Attention or MLP layer + new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + self.transformer_layers.append(i) + + # if not is_torchdynamo_compiling(): + # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # new_layer_key_cache = getattr(self, f"key_cache_{i}") + # new_layer_value_cache = getattr(self, f"value_cache_{i}") + # torch._dynamo.mark_static_address(new_layer_key_cache) + # torch._dynamo.mark_static_address(new_layer_value_cache) + # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) + # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) + # torch._dynamo.mark_static_address(new_layer_conv_state) + # torch._dynamo.mark_static_address(new_layer_ssm_state) + # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") + # new_layer_conv_state = getattr(self, f"conv_states_{i}") + + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + self.conv_states.append(new_layer_conv_state) + self.ssm_states.append(new_layer_ssm_state) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + cache_position = cache_kwargs.get("cache_position") + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = None) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx is None: + layer_idx = self.transformer_layers[0] + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + self.conv_states += [ + torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + ] + self.ssm_states += [ + torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +@dataclass +class AprielHybridCausalOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + attention_weights: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +# This is from LLmaba/Mohawk: https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + # In __init__, pre-allocate these tensors + self.zeros_buffer = torch.zeros((self.n_v_heads, self.headdim), device=device, dtype=dtype) + self.ones_buffer = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=device, dtype=dtype) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward( + self, + u, + return_mixer_matrix=False, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + inference_params=None, + **kwargs, + ): + """ + u: (B, L, D) + Returns: same shape as u + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + ssm_state, conv_state = None, None + if past_key_value is not None: + ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + cache_position = kwargs.get("cache_position", None) + # if inference_params is not None and inference_params.seqlen_offset > 0: + if cache_position is not None and cache_position[0] > 0: + # States are updated inplace + # TODO: make sure using cache_position is correct here + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _, _ = self.step(u, ssm_state, conv_state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if ssm_state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if ssm_state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(ssm_state is not None), + ) + + if ssm_state is not None: + y, ssm_state_update = result + ssm_state.copy_(ssm_state_update) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + def step(self, u, ssm_state, conv_state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state_new = self.convolutional_step(xBC, conv_state) + conv_state.copy_(conv_state_new) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + ssm_state = ssm_state.to(x.dtype) + zeros = self.zeros_buffer.to(A_log.device).to(x.dtype) # Just cast, don't allocate + ones = self.ones_buffer.to(A_log.device).to(x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=ssm_state, # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, ssm_state, conv_state + + # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # device = self.in_proj.weight.device + # # conv_state: + # conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + # conv_state = torch.zeros( + # batch_size, + # self.d_conv, + # self.conv1d.weight.shape[0], + # device=device, + # dtype=conv_dtype, + # ).transpose(1, 2) + # # ssm_state: + # ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + # ssm_state = torch.zeros( + # batch_size, + # self.n_v_heads, + # self.headdim, + # self.d_state, + # device=device, + # dtype=ssm_dtype, + # ) + # return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + # if self.layer_idx not in inference_params.ssm_states: + # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + # batch_size, inference_params.max_seqlen, dtype=torch.float32 + # ) + # Get states + ssm_states = inference_params.ssm_states[self.layer_idx] + conv_states = inference_params.conv_states[self.layer_idx] + if initialize_states: + ssm_states.zero_() + conv_states.zero_() + return ssm_states, conv_states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielSSMDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + **kwargs, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + # outputs["hidden_states"] = hidden_states + outputs = (hidden_states,) + + return outputs + + +class AprielHybridIdentity(nn.Module): + def __init__(self, config: AprielSSMHybridConfig): + super().__init__() + self.config = config + + def forward(self, hidden_states: torch.Tensor, **kwargs): + return (hidden_states,) + + +class AprielSSMHybridModel(MistralModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] + Args: + config: AprielSSMHybridConfig + """ + + def __init__(self, config: AprielSSMHybridConfig, **kwargs): + config_copy = copy.deepcopy(config) + config_copy.num_hidden_layers = 0 + super().__init__(config_copy, **kwargs) + self.config = config + blocks = [] + logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") + for layer_idx, type in enumerate(config.hybrid_block_layout): + if type == "m2d": + blocks.append(AprielSSMDecoderLayer(config, layer_idx)) + elif type == "t": + blocks.append(MistralDecoderLayer(config, layer_idx)) + elif type == "i": + blocks.append(AprielHybridIdentity(config)) + else: + raise ValueError(f"Invalid block type: {type}") + self.layers = nn.ModuleList(blocks) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # OO: Cache is initialized in the `prepare_inputs_for_generation` method, so this can be removed + # if use_cache and past_key_values is None: + # past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielHybridPreTrainedModel(PreTrainedModel): + config_class = AprielSSMHybridConfig + base_model_prefix = "model" + _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MistralRMSNorm): + module.weight.data.fill_(1.0) + + +class AprielSSMHybridForCausalLM(AprielHybridPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config: AprielSSMHybridConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = AprielSSMHybridModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + # "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return AprielHybridCausalOutput( + loss=loss, + logits=logits, + all_hidden_states=outputs.hidden_states, + past_key_values=outputs.past_key_values, + ) + + +__all__ = [ + "AprielSSMHybridForCausalLM", + "AprielSSMHybridModel", + "AprielSSMPreTrainedModel", +] From 93839252b22b88c130bf14d095d17af7e0e1dae9 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 6 Jun 2025 16:19:45 +0000 Subject: [PATCH 109/122] modeling --- .../modeling_ssm_hybrid_apriel15b.py | 459 ++++++++++-------- 1 file changed, 262 insertions(+), 197 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index 9b3a63269..09c659e01 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -26,6 +26,9 @@ logger = logging.get_logger(__name__) +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + class HybridMambaAttentionStaticCache(Cache): def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): super().__init__() # config, batch_size, max_length, device, dtype) @@ -304,6 +307,59 @@ def reset(self): self.ssm_states.zero_() +# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer +# class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache): +# """ +# A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache +# (which has a constant shape regardless of seq_len). + +# This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` +# and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor +# For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, +# while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). +# For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), +# while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, +# and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. +# """ + +# def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): +# config.layers_block_type = config.hybrid_block_layout +# super().__init__(config, batch_size, dtype, device) +# self.has_previous_state = False # only used by mamba +# conv_kernel_size = config.ssm_cfg["d_conv"] +# ssm_state_size = config.ssm_cfg["d_state"] +# intermediate_size = config.ssm_cfg["d_inner"] +# self.n_qk_heads = config.ssm_cfg["n_qk_heads"] +# assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" +# self.head_d = intermediate_size // self.n_qk_heads +# self.layers_block_type = config.hybrid_block_layout + +# self.conv_states = [] +# self.ssm_states = [] +# self.transformer_layers = [] +# for i in range(config.num_hidden_layers): +# if self.layers_block_type[i] == "m2d": +# self.conv_states += [ +# torch.zeros( +# batch_size, +# conv_kernel_size, +# intermediate_size + 2 * self.n_qk_heads * ssm_state_size, +# device=device, +# dtype=dtype, +# ).transpose(1, 2) +# ] +# self.ssm_states += [ +# torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) +# ] +# else: +# self.conv_states += [torch.tensor([[]] * batch_size, device=device)] +# self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] +# self.transformer_layers.append(i) + +# self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] +# self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + @dataclass class AprielHybridCausalOutput(ModelOutput): """Custom output class for MambaLMHeadModel.""" @@ -361,6 +417,17 @@ def materialize_mixer(A_log, B, C, D): return T +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + # This is from LLmaba/Mohawk: https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py class DiscreteMamba2(nn.Module): def __init__( @@ -448,13 +515,6 @@ def __init__( self.zeros_buffer = torch.zeros((self.n_v_heads, self.headdim), device=device, dtype=dtype) self.ones_buffer = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=device, dtype=dtype) - self.use_cuda_graph = False - self.cuda_graph = None - self.graph_input = None - self.graph_output = None - self.graph_ssm_state = None - self.graph_conv_state = None - @property def d_output(self): return self.d_model @@ -466,103 +526,212 @@ def state_to_tensor(self): def forward( self, u, - return_mixer_matrix=False, past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - inference_params=None, + attention_mask: Optional[torch.Tensor] = None, + return_mixer_matrix=False, **kwargs, ): """ u: (B, L, D) Returns: same shape as u + For later refference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bamba/modeling_bamba.py """ - outputs = {} - # assert state is None + assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda" + cache_position = kwargs.get("cache_position", None) batch, seqlen, dim = u.shape - + u = apply_mask_to_padding_states(u, attention_mask) ssm_state, conv_state = None, None - if past_key_value is not None: - ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) - cache_position = kwargs.get("cache_position", None) - # if inference_params is not None and inference_params.seqlen_offset > 0: - if cache_position is not None and cache_position[0] > 0: - # States are updated inplace - # TODO: make sure using cache_position is correct here - u = u.squeeze(1) if len(u.shape) == 3 else u - out, _, _ = self.step(u, ssm_state, conv_state) - out = out.unsqueeze(1) if len(u.shape) == 2 else out - return {"hidden_states": out} - - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if ssm_state is None else seqlen - - # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = F.pad(u, (0, 0, 0, padded_len - seqlen)) - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, + use_precomputed_states = ( + past_key_value is not None + and past_key_value.has_previous_state + and seqlen == 1 + and past_key_value.conv_states[self.layer_idx].shape[0] + == past_key_value.ssm_states[self.layer_idx].shape[0] + == batch + and cache_position is not None + and cache_position[0] > 0 ) + if use_precomputed_states: + ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _, _ = self.step(u, ssm_state, conv_state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out + return {"hidden_states": out} + else: + outputs = {} + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if ssm_state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) - if ssm_state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + if ssm_state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) - # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) - x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(ssm_state is not None), + ) - # SSM forward - result = mamba_chunk_scan_combined( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=A_log, - dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), - B=B, - C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(ssm_state is not None), - ) + if ssm_state is not None: + y, ssm_state_update = result + ssm_state.copy_(ssm_state_update) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + # def forward_original( + # self, + # u, + # return_mixer_matrix=False, + # past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + # inference_params=None, + # **kwargs, + # ): + # """ + # u: (B, L, D) + # Returns: same shape as u + # """ + # outputs = {} + # # assert state is None + # batch, seqlen, dim = u.shape + + # ssm_state, conv_state = None, None + # if past_key_value is not None: + # ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + # cache_position = kwargs.get("cache_position", None) + # # if inference_params is not None and inference_params.seqlen_offset > 0: + # if cache_position is not None and cache_position[0] > 0: + # # States are updated inplace + # # TODO: make sure using cache_position is correct here + # u = u.squeeze(1) if len(u.shape) == 3 else u + # out, _, _ = self.step(u, ssm_state, conv_state) + # out = out.unsqueeze(1) if len(u.shape) == 2 else out + # return {"hidden_states": out} + + # # Hacky way to initialize state during inference + # chunk_size = self.chunk_size if ssm_state is None else seqlen + + # # Pad input to nearest multiple of chunklen + # padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + # u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # # Project input + # xBCzA_log = self.in_proj(u) + # xBC, z, A_log = torch.split( + # xBCzA_log, + # [ + # self.d_inner + 2 * self.n_qk_heads * self.d_state, + # self.d_inner, + # self.n_v_heads, + # ], + # dim=-1, + # ) - if ssm_state is not None: - y, ssm_state_update = result - ssm_state.copy_(ssm_state_update) - else: - y = result + # if ssm_state is not None: + # # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + # xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + # conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # # Convolutional layer + # xBC = self.convolutional_forward(xBC, padded_len) + + # x, B, C = torch.split( + # xBC, + # [ + # self.d_inner, + # self.n_qk_heads * self.d_state, + # self.n_qk_heads * self.d_state, + # ], + # dim=-1, + # ) + + # x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + # B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + # C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # # SSM forward + # result = mamba_chunk_scan_combined( + # x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + # dt=A_log, + # dt_softplus=True, + # A=-torch.ones(self.n_v_heads, device=A_log.device), + # B=B, + # C=C, + # chunk_size=chunk_size, + # # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + # return_final_states=(ssm_state is not None), + # ) - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = rearrange(y + Du, "b l h p -> b l (h p)") + # if ssm_state is not None: + # y, ssm_state_update = result + # ssm_state.copy_(ssm_state_update) + # else: + # y = result - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :] + # Du = torch.einsum("h,blhp->blhp", self.D, x) + # y = rearrange(y + Du, "b l h p -> b l (h p)") - if return_mixer_matrix: - outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] - return outputs + # # Norm and gate + # out = self.out_proj(y * F.silu(z + self.z_bias)) + # outputs["hidden_states"] = out[:, :seqlen, :] + + # if return_mixer_matrix: + # outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + # return outputs def step(self, u, ssm_state, conv_state, **kwargs): """ @@ -584,7 +753,9 @@ def step(self, u, ssm_state, conv_state, **kwargs): ) xBC, conv_state_new = self.convolutional_step(xBC, conv_state) - conv_state.copy_(conv_state_new) # update state in place + if conv_state_new is not None: + raise NotImplementedError("Should not end up here snce only support fast path.") + # conv_state.copy_(conv_state_new) # update state in place, only for slow pass x, B, C = torch.split( xBC, @@ -623,29 +794,6 @@ def step(self, u, ssm_state, conv_state, **kwargs): return out, ssm_state, conv_state - # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - # device = self.in_proj.weight.device - # # conv_state: - # conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - # conv_state = torch.zeros( - # batch_size, - # self.d_conv, - # self.conv1d.weight.shape[0], - # device=device, - # dtype=conv_dtype, - # ).transpose(1, 2) - # # ssm_state: - # ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype - # ssm_state = torch.zeros( - # batch_size, - # self.n_v_heads, - # self.headdim, - # self.d_state, - # device=device, - # dtype=ssm_dtype, - # ) - # return {"conv": conv_state, "ssm": ssm_state} - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): """ conv_state: (batch, d_conv, conv1d.weight.shape[0]) @@ -692,6 +840,7 @@ def convolutional_step(self, xBC, conv_state): self.conv1d.bias, self.activation if self.activation != "identity" else None, ) + return xBC, None else: conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) conv_state[:, :, -1] = xBC @@ -700,94 +849,7 @@ def convolutional_step(self, xBC, conv_state): xBC = xBC + self.conv1d.bias xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype - return xBC, conv_state - - def enable_cuda_graph(self, batch_size=1): - """Capture CUDA graph for the step function""" - if not torch.cuda.is_available(): - return - - # Pre-allocate tensors with fixed shapes - device = next(self.parameters()).device - self.graph_input = torch.zeros(batch_size, self.d_model, device=device, dtype=torch.float16) - self.graph_ssm_state = torch.zeros( - batch_size, self.n_qk_heads, self.headdim, self.d_state, device=device, dtype=torch.float16 - ) - self.graph_conv_state = torch.zeros( - batch_size, - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_conv, - device=device, - dtype=torch.float16, - ) - self.graph_output = torch.zeros(batch_size, self.d_model, device=device, dtype=torch.float16) - - # Warmup runs - torch.cuda.synchronize() - for _ in range(3): - self._step_graph_impl(self.graph_input, self.graph_ssm_state, self.graph_conv_state) - torch.cuda.synchronize() - - # Capture graph - self.cuda_graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self.cuda_graph): - self.graph_output = self._step_graph_impl(self.graph_input, self.graph_ssm_state, self.graph_conv_state) - - self.use_cuda_graph = True - - def _step_graph_impl(self, u, ssm_state, conv_state): - """Graph-compatible version of step function""" - # Same logic as step() but with pre-allocated tensors - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - xBC, conv_state_new = self.convolutional_step(xBC, conv_state) - conv_state.copy_(conv_state_new) # update state in place - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) - B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) - C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) - - ssm_state = ssm_state.to(x.dtype) - zeros = self.zeros_buffer.to(A_log.device).to(x.dtype) # Just cast, don't allocate - ones = self.ones_buffer.to(A_log.device).to(x.dtype) - y = selective_state_update( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=repeat(A_log, "b h -> b h p", p=self.headdim), - dt_softplus=True, - A=-ones, - B=B, - C=C, - state=ssm_state, # will be updated in place - dt_bias=zeros, - D=zeros, - ) - - y = y + self.D[:, None] * x - y = rearrange(y, "b h p -> b (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - - return out + return xBC, conv_state class AprielSSMDecoderLayer(nn.Module): @@ -879,7 +941,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -971,6 +1033,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, From 4b643b794743f8b2c601e2d98aaf9a6cb38f6300 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 6 Jun 2025 16:30:42 +0000 Subject: [PATCH 110/122] modeling --- .../modeling_ssm_hybrid_apriel15b.py | 64 ++++--------------- 1 file changed, 13 insertions(+), 51 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index 09c659e01..03e379e1d 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -302,62 +302,24 @@ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) return self.ssm_states[layer_idx] + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + def reset(self): self.conv_states.zero_() self.ssm_states.zero_() + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") -# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer -# class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache): -# """ -# A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache -# (which has a constant shape regardless of seq_len). - -# This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` -# and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor -# For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, -# while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). -# For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), -# while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, -# and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. -# """ - -# def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): -# config.layers_block_type = config.hybrid_block_layout -# super().__init__(config, batch_size, dtype, device) -# self.has_previous_state = False # only used by mamba -# conv_kernel_size = config.ssm_cfg["d_conv"] -# ssm_state_size = config.ssm_cfg["d_state"] -# intermediate_size = config.ssm_cfg["d_inner"] -# self.n_qk_heads = config.ssm_cfg["n_qk_heads"] -# assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" -# self.head_d = intermediate_size // self.n_qk_heads -# self.layers_block_type = config.hybrid_block_layout - -# self.conv_states = [] -# self.ssm_states = [] -# self.transformer_layers = [] -# for i in range(config.num_hidden_layers): -# if self.layers_block_type[i] == "m2d": -# self.conv_states += [ -# torch.zeros( -# batch_size, -# conv_kernel_size, -# intermediate_size + 2 * self.n_qk_heads * ssm_state_size, -# device=device, -# dtype=dtype, -# ).transpose(1, 2) -# ] -# self.ssm_states += [ -# torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) -# ] -# else: -# self.conv_states += [torch.tensor([[]] * batch_size, device=device)] -# self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] -# self.transformer_layers.append(i) - -# self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] -# self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") @dataclass From 0021ab565248dc18c9665523876ee2cee17644ad Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 6 Jun 2025 20:40:57 +0000 Subject: [PATCH 111/122] evalchemy --- fast_llm/models/ssm/external/eval/run_evalchemy.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 fast_llm/models/ssm/external/eval/run_evalchemy.py diff --git a/fast_llm/models/ssm/external/eval/run_evalchemy.py b/fast_llm/models/ssm/external/eval/run_evalchemy.py new file mode 100644 index 000000000..1cbb5b4da --- /dev/null +++ b/fast_llm/models/ssm/external/eval/run_evalchemy.py @@ -0,0 +1,9 @@ +from eval.eval import cli_evaluate +from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import ( # noqa: F401 + AprielHybrid15bSSMWrapper, + AprielHybridSSMWrapper, + AprielSSMWrapper, +) + +if __name__ == "__main__": + cli_evaluate() From cb9a845a120507c7bb3434741e07a624cbe69ba0 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 9 Jun 2025 17:08:22 +0000 Subject: [PATCH 112/122] tokenizer --- fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index 37c0311b6..9dc04ff5a 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -128,7 +128,7 @@ def __init__(self, pretrained, **kwargs) -> None: super().__init__( pretrained=pretrained, backend=kwargs.pop("backend", "causal"), - tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Instruct-2407/"), max_length=kwargs.pop("max_length", 4096), **kwargs, ) From 2424b85bde70f1ee22b832ebedd4e8d2b8cc066b Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 9 Jun 2025 19:04:52 +0000 Subject: [PATCH 113/122] tokenizer --- fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index 9dc04ff5a..de1b03842 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -128,7 +128,7 @@ def __init__(self, pretrained, **kwargs) -> None: super().__init__( pretrained=pretrained, backend=kwargs.pop("backend", "causal"), - tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Instruct-2407/"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker"), max_length=kwargs.pop("max_length", 4096), **kwargs, ) From e308a2c31cecb9be87bd0264bc92e4a59245cebb Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 9 Jun 2025 23:56:55 +0000 Subject: [PATCH 114/122] skip causal_conv1d --- fast_llm/layers/ssm/discrete_mamba2.py | 36 ++++++++++++++++++++------ 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index d4f9f84d3..83f34bc70 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,7 +1,6 @@ import logging import math -import causal_conv1d import einops import mamba_ssm.ops.triton.ssd_combined import torch @@ -14,6 +13,13 @@ logger = logging.getLogger(__name__) +try: + import causal_conv1d +except ImportError: + # this is needed since we cannot use causal_conv1d on B200 GPUs for now + logger.warning("Note, causal_conv1d not found, will use torch.nn.functional.conv1d instead") + causal_conv1d = None + """ This code is adapted fropm https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py """ @@ -92,7 +98,6 @@ def __init__( else 0.0 ) - # Convolutional layer self.conv1d_weight = ParameterMeta.from_dims( (td_conv, TensorDim("1", 1), td_conv_kernel), init_method=init_uniform_( @@ -229,10 +234,25 @@ def forward(self, hidden_states, kwargs): def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" - xBC = causal_conv1d.causal_conv1d_fn( - xBC.transpose(1, 2), - einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), - self.conv1d_bias, - activation=None if self.activation_name == "identity" else self.activation_name, - ).transpose(1, 2) + if causal_conv1d is None or self.activation_name not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act( + torch.nn.functional.conv1d( + xBC.transpose(1, 2), + self.conv1d_weight, + bias=self.conv1d_bias, + groups=self.conv1d_weight.shape[0], + padding=self.conv_kernel_size - 1, + )[..., :padded_len].transpose(1, 2) + ) + else: + xBC = causal_conv1d.causal_conv1d_fn( + xBC.transpose(1, 2), + einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + self.conv1d_bias, + activation=None if self.activation_name == "identity" else self.activation_name, + ).transpose(1, 2) return xBC From 3193db9fa36013bb8c2baf4c511a6cfc9b63fe87 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 10 Jun 2025 00:07:45 +0000 Subject: [PATCH 115/122] causal_conv1d optional --- setup.cfg | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 2e3b549fc..8b686f507 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,7 @@ CORE = safetensors>=0.4.4 # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation flash-attn==2.7.2.post1 - mamba_ssm[causal-conv1d]==2.2.4 + mamba_ssm==2.2.4 # Required for some optional features and tools. @@ -44,6 +44,8 @@ OPTIONAL = # Miscellanous requests>=2.32.3 tqdm>=4.66.3 + # For causal_conv1d + causal_conv1d>=1.4.0 DEV = # Pre-commit git hook From 0f45d4a0b410f39096e6995d4783e8cc6006c8aa Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 11 Jun 2025 14:57:24 +0000 Subject: [PATCH 116/122] 5b notebook --- fast_llm/models/ssm/external/5B_hybrid.ipynb | 374 ++++++ .../modeling_ssm_hybrid_apriel15b_cg.py | 1092 ----------------- 2 files changed, 374 insertions(+), 1092 deletions(-) create mode 100644 fast_llm/models/ssm/external/5B_hybrid.ipynb delete mode 100644 fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b_cg.py diff --git a/fast_llm/models/ssm/external/5B_hybrid.ipynb b/fast_llm/models/ssm/external/5B_hybrid.ipynb new file mode 100644 index 000000000..a79e1b176 --- /dev/null +++ b/fast_llm/models/ssm/external/5B_hybrid.ipynb @@ -0,0 +1,374 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "import torch\n", + "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "fast_llm_path = \"/home/toolkit/dev/Fast-LLM\"\n", + "\n", + "# add fast_llm to the python path\n", + "import sys\n", + "sys.path.append(fast_llm_path)\n", + "from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig\n", + "from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridModel, AprielSSMDecoderLayer, AprielSSMHybridForCausalLM\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "base = 0.612615\n", + "layer_scores = {\n", + " \"22\": 0.607389,\n", + " \"24\": 0.603498,\n", + " \"19\": 0.597907,\n", + " \"27\": 0.597173,\n", + " \"20\": 0.590442,\n", + " \"5\": 0.578949,\n", + " \"4\": 0.576852,\n", + " \"9\": 0.576484,\n", + " \"23\": 0.574833,\n", + " \"7\": 0.571860,\n", + " \"8\": 0.571790,\n", + " \"6\": 0.571614,\n", + " \"2\": 0.571330,\n", + " \"26\": 0.570205,\n", + " \"11\": 0.567128,\n", + " \"14\": 0.566175,\n", + " \"15\": 0.566076,\n", + " \"3\": 0.562861,\n", + " \"1\": 0.560154,\n", + " \"13\": 0.559304,\n", + " \"16\": 0.559017,\n", + " \"10\": 0.558789,\n", + " \"12\": 0.555186,\n", + " \"17\": 0.554236,\n", + " \"25\": 0.549215,\n", + " \"18\": 0.537257,\n", + " \"0\": 0.233085,\n", + "}\n", + "layer_scores = {k: base - v for k, v in layer_scores.items()}\n", + "layer_importanfce = sorted(layer_scores.items(), key=lambda x: x[1])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('22', 0.005226000000000064),\n", + " ('24', 0.009117000000000042),\n", + " ('19', 0.014708000000000054),\n", + " ('27', 0.015442000000000067),\n", + " ('20', 0.022173),\n", + " ('5', 0.033665999999999974),\n", + " ('4', 0.03576299999999999),\n", + " ('9', 0.036131000000000024),\n", + " ('23', 0.03778199999999998),\n", + " ('7', 0.040754999999999986),\n", + " ('8', 0.040825),\n", + " ('6', 0.041001000000000065),\n", + " ('2', 0.041285000000000016),\n", + " ('26', 0.04241000000000006),\n", + " ('11', 0.045487000000000055),\n", + " ('14', 0.04644000000000004),\n", + " ('15', 0.046539),\n", + " ('3', 0.049754000000000076),\n", + " ('1', 0.05246099999999998),\n", + " ('13', 0.053311),\n", + " ('16', 0.053598000000000035),\n", + " ('10', 0.05382600000000004),\n", + " ('12', 0.05742900000000006),\n", + " ('17', 0.05837900000000007),\n", + " ('25', 0.06340000000000001),\n", + " ('18', 0.07535800000000004),\n", + " ('0', 0.37953000000000003)]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer_importanfce" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create hybrid with any number of SSM layers" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "device = \"cuda\"\n", + "n_hybrid = 1\n", + "\n", + "index_swaped = []\n", + "hybrid_block_layout = [\"t\"] * config.num_hidden_layers\n", + "for i in range(n_hybrid):\n", + " hybrid_block_layout[int(layer_importanfce[i][0])] = \"m2d\"\n", + " index_swaped.append(int(layer_importanfce[i][0]))\n", + "\n", + "hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),\n", + " hybrid_block_layout=hybrid_block_layout,\n", + " ssm_cfg={\n", + " \"d_state\": 64,\n", + " \"n_v_heads\": 24,\n", + " \"n_qk_heads\": 24,\n", + " \"expand\": 1,\n", + " \"chunk_size\": 128,\n", + " \"activation\": \"identity\",\n", + " \"bias\": False,\n", + " \"d_inner\": 24 * 128, # num_heads * head_dim\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['t',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrdif_apriel_config.hybrid_block_layout" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMHybridForCausalLM(\n", + " (model): AprielSSMHybridModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-21): 22 x AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (22): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (23-27): 5 x AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrid_apriel_model = AprielSSMHybridForCausalLM(hybrdif_apriel_config)\n", + "hybrid_apriel_model.to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "A new version of the following files was downloaded from https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct:\n", + "- modeling_apriel.py\n", + ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n", + "Fetching 2 files: 100%|██████████| 2/2 [00:00<00:00, 23.31it/s]\n", + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 4.11it/s]\n" + ] + } + ], + "source": [ + "\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "missing, unexpected = hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing keys: ['model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight']\n", + "Unexpected keys: ['model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight']\n" + ] + } + ], + "source": [ + "# unexpected will contain keys from the SSM layers we added\n", + "print(\"Missing keys:\", missing)\n", + "# unexpected will contain keys from the transformer layers we replaced\n", + "print(\"Unexpected keys:\", unexpected)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# save the hybrid model\n", + "output_path = \"/mnt/checkpoints/ssm/iterative_hybrids_5b\"\n", + "assert len(index_swaped) == 1\n", + "layer_swaped = index_swaped[0]\n", + "hybrid_apriel_model.save_pretrained(\n", + " f\"{output_path}/apriel_ssm_instruct5b_hybrid_{layer_swaped+1}ssm_leastimportant_32h_init_rand\"\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fast_llm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b_cg.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b_cg.py deleted file mode 100644 index 506ebd7b4..000000000 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b_cg.py +++ /dev/null @@ -1,1092 +0,0 @@ -import copy -from dataclasses import dataclass -from functools import partial -from typing import Any, Optional, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from einops import rearrange, repeat -from mamba_ssm.ops.triton.selective_state_update import selective_state_update -from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined -from torch import nn -from transformers import GenerationMixin -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm -from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging -from transformers.utils.generic import ModelOutput - -from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig - -logger = logging.get_logger(__name__) - - -class HybridMambaAttentionStaticCache(Cache): - def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): - super().__init__() # config, batch_size, max_length, device, dtype) - self.dtype = dtype - self.hybrid_override_pattern = config.hybrid_block_layout - self.has_previous_state = False # only used by mamba - intermediate_size = config.ssm_cfg["d_inner"] - ssm_state_size = config.ssm_cfg["d_state"] - conv_kernel_size = config.ssm_cfg["d_conv"] - self.n_qk_heads = config.ssm_cfg["n_qk_heads"] - assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" - self.head_d = intermediate_size // self.n_qk_heads - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - self.batch_size = batch_size - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) - self.max_cache_len = config.max_position_embeddings if max_length is None else max_length - - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) - - for i in range(config.num_hidden_layers): - if self.hybrid_override_pattern[i] == "m2d": - # Mamba layer - new_layer_conv_state = torch.zeros( - batch_size, - conv_kernel_size, - intermediate_size + 2 * self.n_qk_heads * ssm_state_size, - device=device, - dtype=dtype, - ).transpose(1, 2) - - new_layer_ssm_state = torch.zeros( - batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype - ) - new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) - new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) - else: - # Attention or MLP layer - new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) - new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) - new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) - self.transformer_layers.append(i) - - # if not is_torchdynamo_compiling(): - # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) - # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) - # new_layer_key_cache = getattr(self, f"key_cache_{i}") - # new_layer_value_cache = getattr(self, f"value_cache_{i}") - # torch._dynamo.mark_static_address(new_layer_key_cache) - # torch._dynamo.mark_static_address(new_layer_value_cache) - # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) - # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) - # torch._dynamo.mark_static_address(new_layer_conv_state) - # torch._dynamo.mark_static_address(new_layer_ssm_state) - # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") - # new_layer_conv_state = getattr(self, f"conv_states_{i}") - - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - self.conv_states.append(new_layer_conv_state) - self.ssm_states.append(new_layer_ssm_state) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input - to know how where to write in the cache. - - Return: - A tuple containing the updated key and value states. - """ - - cache_position = cache_kwargs.get("cache_position") - - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place - # operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - return k_out, v_out - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: Optional[int] = None) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx is None: - layer_idx = self.transformer_layers[0] - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len - - # Copied from modeling_mamba2.py - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) - return self.ssm_states[layer_idx] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - -# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py -class HybridMambaAttentionDynamicCache(DynamicCache): - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): - super().__init__() - self.dtype = dtype - self.hybrid_override_pattern = config.hybrid_block_layout - self.has_previous_state = False # only used by mamba - intermediate_size = config.ssm_cfg["d_inner"] - ssm_state_size = config.ssm_cfg["d_state"] - conv_kernel_size = config.ssm_cfg["d_conv"] - self.n_qk_heads = config.ssm_cfg["n_qk_heads"] - assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" - self.head_d = intermediate_size // self.n_qk_heads - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.hybrid_override_pattern[i] == "m2d": - # Mamba layer - self.conv_states += [ - torch.zeros( - batch_size, - conv_kernel_size, - intermediate_size + 2 * self.n_qk_heads * ssm_state_size, - device=device, - dtype=dtype, - ).transpose(1, 2) - ] - self.ssm_states += [ - torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) - ] - else: - # Attention or MLP layer - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - # Copied from modeling_mamba2.py - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) - return self.ssm_states[layer_idx] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - -@dataclass -class AprielHybridCausalOutput(ModelOutput): - """Custom output class for MambaLMHeadModel.""" - - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None - last_hidden_state: Optional[torch.FloatTensor] = None - attention_weights: Optional[torch.FloatTensor] = None - past_key_values: Optional[Cache] = None - - -def segsum(x): - """More stable segment sum calculation.""" - # [1, 2, 3] - T = x.size(-1) - x = repeat(x, "... d -> ... d e", e=T) - # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) - x = x.masked_fill(~mask, 0) - # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] - x_segsum = torch.cumsum(x, dim=-2) - # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) - x_segsum = x_segsum.masked_fill(~mask, -torch.inf) - return x_segsum - - -def materialize_mixer(A_log, B, C, D): - """ - Since the transfer matrix will be equated to the attention matrix, - we need to support the form: torch.matmul(attn_weights, value_states). - Thus, y = torch.matmul(T, X) - Arguments: - A_log: (batch, length, n_heads) - B: (batch, length, n_heads, d_state) - C: (batch, length, n_heads, d_state) - Return: - T: (batch, n_heads, length, length) - """ - batch_size, length, n_heads, d_state = B.shape - assert A_log.shape == (batch_size, length, n_heads) - assert B.shape == C.shape == (batch_size, length, n_heads, d_state) - - # Compute: - A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") - powers = torch.exp(segsum(A_log)) - T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) - - # Add D: - if D is not None: - T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) - - T = rearrange(T, "b h z l -> b h l z") - return T - - -# This is from LLmaba/Mohawk: https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py -class DiscreteMamba2(nn.Module): - def __init__( - self, - d_model, - d_state=64, - n_qk_heads=32, - n_v_heads=32, - d_conv=4, - expand=1, - activation="identity", - bias=False, - conv_bias=True, - chunk_size=128, - layer_idx=None, - device=None, - dtype=None, - d_inner=None, - **kwargs, # Absorb kwarg for general module - ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. - Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" - - Other options are all experimental and should not need to be configured - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = self.expand * self.d_model if d_inner is None else d_inner - self.n_qk_heads = n_qk_heads - self.n_v_heads = n_v_heads - self.headdim = self.d_inner // self.n_v_heads - assert self.n_v_heads == self.d_inner // self.headdim - assert self.d_inner % self.headdim == 0 - assert self.n_v_heads % self.n_qk_heads == 0 - self.activation = activation - self.chunk_size = chunk_size - self.layer_idx = layer_idx - self.bias = bias - self.kwargs = kwargs - - # Projections - self.in_proj = nn.Linear( - self.d_model, - 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, - bias=bias, - **factory_kwargs, - ) - self.z_bias = ( - nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 - ) # make sure z_bias always exists - - # Convolutional layer - conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state - self.conv_bias = conv_bias - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - **factory_kwargs, - ) - - # Activation after conv - if self.activation == "identity": - self.act = nn.Identity() - elif self.activation in ["silu", "swish"]: - self.act = nn.SiLU() - else: - raise ValueError(f"Unknown activation {self.activation}") - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) - self.D._optim = {"weight_decay": 0.0} - - # out_proj - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - # In __init__, pre-allocate these tensors - self.zeros_buffer = torch.zeros((self.n_v_heads, self.headdim), device=device, dtype=dtype) - self.ones_buffer = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=device, dtype=dtype) - - @property - def d_output(self): - return self.d_model - - @property - def state_to_tensor(self): - return self.layer.state_to_tensor - - def forward( - self, - u, - return_mixer_matrix=False, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - inference_params=None, - **kwargs, - ): - """ - u: (B, L, D) - Returns: same shape as u - """ - outputs = {} - # assert state is None - batch, seqlen, dim = u.shape - - ssm_state, conv_state = None, None - if past_key_value is not None: - ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) - cache_position = kwargs.get("cache_position", None) - # if inference_params is not None and inference_params.seqlen_offset > 0: - if cache_position is not None and cache_position[0] > 0: - # States are updated inplace - # TODO: make sure using cache_position is correct here - u = u.squeeze(1) if len(u.shape) == 3 else u - out, _, _ = self.step(u, ssm_state, conv_state) - out = out.unsqueeze(1) if len(u.shape) == 2 else out - return {"hidden_states": out} - - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if ssm_state is None else seqlen - - # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = F.pad(u, (0, 0, 0, padded_len - seqlen)) - - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - if ssm_state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) - - # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) - - # SSM forward - result = mamba_chunk_scan_combined( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=A_log, - dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), - B=B, - C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(ssm_state is not None), - ) - - if ssm_state is not None: - y, ssm_state_update = result - ssm_state.copy_(ssm_state_update) - else: - y = result - - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = rearrange(y + Du, "b l h p -> b l (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :] - - if return_mixer_matrix: - outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] - return outputs - - def step(self, u, ssm_state, conv_state, **kwargs): - """ - u: (B D) - state: dict of states - Returns: same shape as u - """ - - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - xBC, conv_state_new = self.convolutional_step(xBC, conv_state) - conv_state.copy_(conv_state_new) # update state in place - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) - B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) - C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) - - ssm_state = ssm_state.to(x.dtype) - zeros = self.zeros_buffer.to(A_log.device).to(x.dtype) # Just cast, don't allocate - ones = self.ones_buffer.to(A_log.device).to(x.dtype) - y = selective_state_update( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=repeat(A_log, "b h -> b h p", p=self.headdim), - dt_softplus=True, - A=-ones, - B=B, - C=C, - state=ssm_state, # will be updated in place - dt_bias=zeros, - D=zeros, - ) - - y = y + self.D[:, None] * x - y = rearrange(y, "b h p -> b (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - - return out, ssm_state, conv_state - - # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - # device = self.in_proj.weight.device - # # conv_state: - # conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - # conv_state = torch.zeros( - # batch_size, - # self.d_conv, - # self.conv1d.weight.shape[0], - # device=device, - # dtype=conv_dtype, - # ).transpose(1, 2) - # # ssm_state: - # ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype - # ssm_state = torch.zeros( - # batch_size, - # self.n_v_heads, - # self.headdim, - # self.d_state, - # device=device, - # dtype=ssm_dtype, - # ) - # return {"conv": conv_state, "ssm": ssm_state} - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - """ - conv_state: (batch, d_conv, conv1d.weight.shape[0]) - ssm_state: (batch, n_qk_heads, headdim, d_state) - """ - assert self.layer_idx is not None - # Allocate memory if not exists - # if self.layer_idx not in inference_params.ssm_states: - # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( - # batch_size, inference_params.max_seqlen, dtype=torch.float32 - # ) - # Get states - ssm_states = inference_params.ssm_states[self.layer_idx] - conv_states = inference_params.conv_states[self.layer_idx] - if initialize_states: - ssm_states.zero_() - conv_states.zero_() - return ssm_states, conv_states - - def convolutional_forward(self, xBC, padded_len): - if causal_conv1d_fn is None or self.activation not in [ - "silu", - "swish", - "identity", - ]: - xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) - else: - xBC = causal_conv1d_fn( - xBC.transpose(1, 2), - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - activation=None if self.activation == "identity" else self.activation, - ).transpose(1, 2) - return xBC - - def convolutional_step(self, xBC, conv_state): - # Convolutional layer - conv_state = conv_state.to(xBC.dtype) - if causal_conv1d_update: - xBC = causal_conv1d_update( - xBC, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation if self.activation != "identity" else None, - ) - else: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = xBC - xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv_bias: - xBC = xBC + self.conv1d.bias - xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype - - return xBC, conv_state - - -class AprielSSMDecoderLayer(nn.Module): - def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): - super().__init__(**kwargs) - factory_kwargs = {"device": device, "dtype": dtype} - self.hidden_size = config.hidden_size - - self.mixer = DiscreteMamba2( - d_model=config.hidden_size, - layer_idx=layer_idx, - **config.ssm_cfg, - **factory_kwargs, - ) - - self.mlp = MistralMLP(config) - self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, hidden_states: torch.Tensor, **kwargs - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - - outputs = {} - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - mixer_outputs = self.mixer( - hidden_states, - **kwargs, - ) - - hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - # outputs["hidden_states"] = hidden_states - outputs = (hidden_states,) - - return outputs - - -class AprielHybridIdentity(nn.Module): - def __init__(self, config: AprielSSMHybridConfig): - super().__init__() - self.config = config - - def forward(self, hidden_states: torch.Tensor, **kwargs): - return (hidden_states,) - - -class AprielSSMHybridModel(MistralModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] - Args: - config: AprielSSMHybridConfig - """ - - def __init__(self, config: AprielSSMHybridConfig, **kwargs): - config_copy = copy.deepcopy(config) - config_copy.num_hidden_layers = 0 - super().__init__(config_copy, **kwargs) - self.config = config - blocks = [] - logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") - for layer_idx, type in enumerate(config.hybrid_block_layout): - if type == "m2d": - blocks.append(AprielSSMDecoderLayer(config, layer_idx)) - elif type == "t": - blocks.append(MistralDecoderLayer(config, layer_idx)) - elif type == "i": - blocks.append(AprielHybridIdentity(config)) - else: - raise ValueError(f"Invalid block type: {type}") - self.layers = nn.ModuleList(blocks) - - # Initialize weights and apply final processing - self.post_init() - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> BaseModelOutputWithPast: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - # OO: Cache is initialized in the `prepare_inputs_for_generation` method, so this can be removed - # if use_cache and past_key_values is None: - # past_key_values = DynamicCache() - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - - -class AprielHybridPreTrainedModel(PreTrainedModel): - config_class = AprielSSMHybridConfig - base_model_prefix = "model" - _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_attention_backend = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, MistralRMSNorm): - module.weight.data.fill_(1.0) - - -class AprielSSMHybridForCausalLM(AprielHybridPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - - def __init__(self, config: AprielSSMHybridConfig, **kwargs): - super().__init__(config, **kwargs) - self.model = AprielSSMHybridModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - output_router_logits=False, - cache_position=None, - position_ids=None, - use_cache=True, - **kwargs, - ): - # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - - empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) - - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - if not empty_past_kv: - if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - else: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device - ) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if not empty_past_kv: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and empty_past_kv: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "output_router_logits": output_router_logits, - # "logits_to_keep": self.config.num_logits_to_keep, - "cache_position": cache_position, - } - ) - return model_inputs - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MistralForCausalLM - - >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return AprielHybridCausalOutput( - loss=loss, - logits=logits, - all_hidden_states=outputs.hidden_states, - past_key_values=outputs.past_key_values, - ) - - -__all__ = [ - "AprielSSMHybridForCausalLM", - "AprielSSMHybridModel", - "AprielSSMPreTrainedModel", -] From fc6ed64891c43132d265dca2f7d48e46146e0d2e Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 11 Jun 2025 15:02:18 +0000 Subject: [PATCH 117/122] nvm --- fast_llm/models/ssm/external/5B_hybrid.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/external/5B_hybrid.ipynb b/fast_llm/models/ssm/external/5B_hybrid.ipynb index a79e1b176..3560b825d 100644 --- a/fast_llm/models/ssm/external/5B_hybrid.ipynb +++ b/fast_llm/models/ssm/external/5B_hybrid.ipynb @@ -329,7 +329,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ From 5fb3a3966a193e0f2916073f1db2f94bc0beeb2c Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 11 Jun 2025 15:02:30 +0000 Subject: [PATCH 118/122] nvm --- fast_llm/models/ssm/external/5B_hybrid.ipynb | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/fast_llm/models/ssm/external/5B_hybrid.ipynb b/fast_llm/models/ssm/external/5B_hybrid.ipynb index 3560b825d..2924faa82 100644 --- a/fast_llm/models/ssm/external/5B_hybrid.ipynb +++ b/fast_llm/models/ssm/external/5B_hybrid.ipynb @@ -118,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -149,7 +149,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -185,7 +185,7 @@ " 't']" ] }, - "execution_count": 8, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -339,15 +339,9 @@ "layer_swaped = index_swaped[0]\n", "hybrid_apriel_model.save_pretrained(\n", " f\"{output_path}/apriel_ssm_instruct5b_hybrid_{layer_swaped+1}ssm_leastimportant_32h_init_rand\"\n", - " )\n" + " )\n", + "print(f\"Hybrid model saved to {output_path}/apriel_ssm_instruct5b_hybrid_{layer_swaped+1}ssm_leastimportant_32h_init_rand\")\n" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { From 386f3c1583c2f4c352ca7a2be0b694572fff3bcb Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 11 Jun 2025 20:01:29 +0000 Subject: [PATCH 119/122] clean --- .../ssm/external/make_hybrid_checkpoint.py | 41 ------------------- 1 file changed, 41 deletions(-) delete mode 100644 fast_llm/models/ssm/external/make_hybrid_checkpoint.py diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint.py deleted file mode 100644 index a0616ab64..000000000 --- a/fast_llm/models/ssm/external/make_hybrid_checkpoint.py +++ /dev/null @@ -1,41 +0,0 @@ -import gc - -import click -import torch -from transformers import AutoConfig, AutoModelForCausalLM - -from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig -from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielSSMHybridForCausalLM - -device = "cuda" if torch.cuda.is_available() else "cpu" - - -@click.command() -@click.option("--identity_index", type=int, required=True) -@click.option("--save_dir", type=str, required=True) -def main(identity_index: int, save_dir: str): - checkpoint = "ServiceNow-AI/Apriel-Nemotron-15b-Thinker" - config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True) - - hybrid_block_layout = ["t"] * config.num_hidden_layers - if identity_index >= 0: - hybrid_block_layout[identity_index] = "i" - - hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(), hybrid_block_layout=hybrid_block_layout) - hybrid_apriel_model = AprielSSMHybridForCausalLM(hybrdif_apriel_config) - hybrid_apriel_model.to(dtype=torch.bfloat16).to(device) - - apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True) - apriel_state_dict = apriel_model.state_dict() - hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False) - - hybrid_apriel_model.save_pretrained(save_dir, save_config=True) - torch.cuda.empty_cache() - del hybrid_apriel_model - del apriel_model - del apriel_state_dict - gc.collect() - - -if __name__ == "__main__": - main() From 73284d83d9a508949fb7798fb086e268b8e06c3b Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 11 Jun 2025 20:07:30 +0000 Subject: [PATCH 120/122] clean --- fast_llm/models/ssm/external/15B_hybrid.ipynb | 516 ------------------ fast_llm/models/ssm/external/5B_hybrid.ipynb | 368 ------------- 2 files changed, 884 deletions(-) delete mode 100644 fast_llm/models/ssm/external/15B_hybrid.ipynb delete mode 100644 fast_llm/models/ssm/external/5B_hybrid.ipynb diff --git a/fast_llm/models/ssm/external/15B_hybrid.ipynb b/fast_llm/models/ssm/external/15B_hybrid.ipynb deleted file mode 100644 index ad4773c69..000000000 --- a/fast_llm/models/ssm/external/15B_hybrid.ipynb +++ /dev/null @@ -1,516 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "import gc\n", - "\n", - "import click\n", - "import torch\n", - "from transformers import AutoConfig, AutoModelForCausalLM\n", - "from transformers import MistralForCausalLM\n", - "\n", - "from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig\n", - "from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielSSMHybridForCausalLM" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Slam 15B upcycled" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - " Lead the weights of https://huggingface.co/ServiceNow-AI/Slam-15B-Upcycled/ into Thiked modeling, it shoudl work" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "sys.path.append(\"/home/toolkit/dev/fml-ops/__oo_playground\")\n", - "from results_analysis.results_loader import ResultsLoader\n", - "layer_importance_path = \"/mnt/evaluations/training_evaluation/model_runs/lm_eval_runner/apriel_ssm_importance/\"\n", - "results_loader = ResultsLoader(layer_importance_path)\n", - "\n", - "results_loader.deserialize_results()\n", - "results_df = results_loader.to_df()\n", - "results_df[\"layer_index\"] = results_df.apply(lambda row: int(row[\"model_name_sanitized\"].split(\"_\")[-1] if \"layers_\" in row[\"model_name_sanitized\"] else -1), axis=1)\n", - "results_df = results_df[results_df[\"metric\"] == \"acc_norm\"]\n", - "columns_to_keep = [\"layer_index\", \"metric_value\"]\n", - "results_df = results_df[columns_to_keep]\n", - "layer_importance = results_df.groupby(\"layer_index\").mean()\n", - "layer_importance = layer_importance.sort_values(by=\"metric_value\", ascending=False).reset_index()\n", - "layer_importance = layer_importance[layer_importance[\"layer_index\"]!= -1]\n", - "layer_importance = list(layer_importance[\"layer_index\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[22,\n", - " 25,\n", - " 20,\n", - " 31,\n", - " 29,\n", - " 46,\n", - " 23,\n", - " 26,\n", - " 33,\n", - " 24,\n", - " 47,\n", - " 27,\n", - " 21,\n", - " 41,\n", - " 17,\n", - " 18,\n", - " 34,\n", - " 42,\n", - " 44,\n", - " 30,\n", - " 16,\n", - " 8,\n", - " 43,\n", - " 35,\n", - " 19,\n", - " 38,\n", - " 15,\n", - " 28,\n", - " 32,\n", - " 45,\n", - " 37,\n", - " 40,\n", - " 7,\n", - " 36,\n", - " 13,\n", - " 10,\n", - " 5,\n", - " 39,\n", - " 6,\n", - " 14,\n", - " 4,\n", - " 12,\n", - " 9,\n", - " 48,\n", - " 1,\n", - " 3,\n", - " 11,\n", - " 49,\n", - " 0]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "layer_importance" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "path_thinker = \"/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker\"\n", - "n_ssm = 25\n", - "\n", - "config_thinker = AutoConfig.from_pretrained(path_thinker)\n", - "hybrid_block_layout = [\"t\"] * config_thinker.num_hidden_layers\n", - "\n", - "for i in range(n_ssm):\n", - " hybrid_block_layout[layer_importance[i]] = \"m2d\"\n", - "\n", - "config_hybrid = AprielSSMHybridConfig(\n", - " **config_thinker.to_dict(),\n", - " hybrid_block_layout=hybrid_block_layout,\n", - " ssm_cfg = {\n", - " \"d_state\": 64,\n", - " \"n_v_heads\": 32,\n", - " \"n_qk_heads\": 32,\n", - " \"expand\": 1,\n", - " \"chunk_size\": 128,\n", - " \"activation\": \"identity\",\n", - " \"bias\": False,\n", - " \"d_conv\": 4,\n", - " \"d_inner\": 32 * 128\n", - " }\n", - ")\n", - "model_hybrid = AprielSSMHybridForCausalLM(config_hybrid)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['t',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 'm2d',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 't',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 't',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 'm2d',\n", - " 't',\n", - " 'm2d',\n", - " 'm2d',\n", - " 't',\n", - " 't']" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "hybrid_block_layout" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "You are using a model of type llama to instantiate a model of type mistral. This is not supported for all configurations of models and can yield errors.\n", - "Loading checkpoint shards: 0%| | 0/4 [00:00 2\u001b[0m model_base \u001b[38;5;241m=\u001b[39m \u001b[43mMistralForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath_base\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m model_hybrid\u001b[38;5;241m.\u001b[39mload_state_dict(model_base\u001b[38;5;241m.\u001b[39mstate_dict(), strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:279\u001b[0m, in \u001b[0;36mrestore_default_torch_dtype.._wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 277\u001b[0m old_dtype \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mget_default_dtype()\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 279\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 281\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_default_dtype(old_dtype)\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:4342\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 4336\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_autoset_attn_implementation(\n\u001b[1;32m 4337\u001b[0m config, use_flash_attention_2\u001b[38;5;241m=\u001b[39muse_flash_attention_2, torch_dtype\u001b[38;5;241m=\u001b[39mtorch_dtype, device_map\u001b[38;5;241m=\u001b[39mdevice_map\n\u001b[1;32m 4338\u001b[0m )\n\u001b[1;32m 4340\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ContextManagers(model_init_context):\n\u001b[1;32m 4341\u001b[0m \u001b[38;5;66;03m# Let's make sure we don't run the init function of buffer modules\u001b[39;00m\n\u001b[0;32m-> 4342\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4344\u001b[0m \u001b[38;5;66;03m# Make sure to tie the weights correctly\u001b[39;00m\n\u001b[1;32m 4345\u001b[0m model\u001b[38;5;241m.\u001b[39mtie_weights()\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/models/mistral/modeling_mistral.py:729\u001b[0m, in \u001b[0;36mMistralForCausalLM.__init__\u001b[0;34m(self, config)\u001b[0m\n\u001b[1;32m 727\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, config):\n\u001b[1;32m 728\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(config)\n\u001b[0;32m--> 729\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m=\u001b[39m \u001b[43mMistralModel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 730\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvocab_size \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39mvocab_size\n\u001b[1;32m 731\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mLinear(config\u001b[38;5;241m.\u001b[39mhidden_size, config\u001b[38;5;241m.\u001b[39mvocab_size, bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/models/mistral/modeling_mistral.py:440\u001b[0m, in \u001b[0;36mMistralModel.__init__\u001b[0;34m(self, config)\u001b[0m\n\u001b[1;32m 437\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_idx \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39mpad_token_id\n\u001b[1;32m 438\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvocab_size \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39mvocab_size\n\u001b[0;32m--> 440\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membed_tokens \u001b[38;5;241m=\u001b[39m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mEmbedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvocab_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhidden_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding_idx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 441\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mModuleList(\n\u001b[1;32m 442\u001b[0m [MistralDecoderLayer(config, layer_idx) \u001b[38;5;28;01mfor\u001b[39;00m layer_idx \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(config\u001b[38;5;241m.\u001b[39mnum_hidden_layers)]\n\u001b[1;32m 443\u001b[0m )\n\u001b[1;32m 444\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm \u001b[38;5;241m=\u001b[39m MistralRMSNorm(config\u001b[38;5;241m.\u001b[39mhidden_size, eps\u001b[38;5;241m=\u001b[39mconfig\u001b[38;5;241m.\u001b[39mrms_norm_eps)\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/sparse.py:144\u001b[0m, in \u001b[0;36mEmbedding.__init__\u001b[0;34m(self, num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight, _freeze, device, dtype)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscale_grad_by_freq \u001b[38;5;241m=\u001b[39m scale_grad_by_freq\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _weight \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mweight \u001b[38;5;241m=\u001b[39m Parameter(\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mempty\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnum_embeddings\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43membedding_dim\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfactory_kwargs\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 145\u001b[0m requires_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;129;01mnot\u001b[39;00m _freeze)\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreset_parameters()\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", - "\u001b[0;31mRuntimeError\u001b[0m: [enforce fail at alloc_cpu.cpp:117] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 2684354560 bytes. Error code 12 (Cannot allocate memory)" - ] - } - ], - "source": [ - "path_base = path_thinker\n", - "model_base = MistralForCausalLM.from_pretrained(path_base)\n", - "\n", - "model_hybrid.load_state_dict(model_base.state_dict(), strict=False)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# model_hybrid.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_1ssm_leastimportant_32h_init_rand\") # 1 ssm\n", - "# model_hybrid.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_1ssm_0th_32h_init_rand\") # 1 ssm\n", - "model_hybrid.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_interleaved_ssm_starting0th_32h_init_rand\") # 1 ssm\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "fast_llm", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/fast_llm/models/ssm/external/5B_hybrid.ipynb b/fast_llm/models/ssm/external/5B_hybrid.ipynb deleted file mode 100644 index 2924faa82..000000000 --- a/fast_llm/models/ssm/external/5B_hybrid.ipynb +++ /dev/null @@ -1,368 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "import torch\n", - "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", - "\n", - "fast_llm_path = \"/home/toolkit/dev/Fast-LLM\"\n", - "\n", - "# add fast_llm to the python path\n", - "import sys\n", - "sys.path.append(fast_llm_path)\n", - "from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig\n", - "from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridModel, AprielSSMDecoderLayer, AprielSSMHybridForCausalLM\n", - "\n", - "%load_ext autoreload\n", - "%autoreload 2\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "base = 0.612615\n", - "layer_scores = {\n", - " \"22\": 0.607389,\n", - " \"24\": 0.603498,\n", - " \"19\": 0.597907,\n", - " \"27\": 0.597173,\n", - " \"20\": 0.590442,\n", - " \"5\": 0.578949,\n", - " \"4\": 0.576852,\n", - " \"9\": 0.576484,\n", - " \"23\": 0.574833,\n", - " \"7\": 0.571860,\n", - " \"8\": 0.571790,\n", - " \"6\": 0.571614,\n", - " \"2\": 0.571330,\n", - " \"26\": 0.570205,\n", - " \"11\": 0.567128,\n", - " \"14\": 0.566175,\n", - " \"15\": 0.566076,\n", - " \"3\": 0.562861,\n", - " \"1\": 0.560154,\n", - " \"13\": 0.559304,\n", - " \"16\": 0.559017,\n", - " \"10\": 0.558789,\n", - " \"12\": 0.555186,\n", - " \"17\": 0.554236,\n", - " \"25\": 0.549215,\n", - " \"18\": 0.537257,\n", - " \"0\": 0.233085,\n", - "}\n", - "layer_scores = {k: base - v for k, v in layer_scores.items()}\n", - "layer_importanfce = sorted(layer_scores.items(), key=lambda x: x[1])\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[('22', 0.005226000000000064),\n", - " ('24', 0.009117000000000042),\n", - " ('19', 0.014708000000000054),\n", - " ('27', 0.015442000000000067),\n", - " ('20', 0.022173),\n", - " ('5', 0.033665999999999974),\n", - " ('4', 0.03576299999999999),\n", - " ('9', 0.036131000000000024),\n", - " ('23', 0.03778199999999998),\n", - " ('7', 0.040754999999999986),\n", - " ('8', 0.040825),\n", - " ('6', 0.041001000000000065),\n", - " ('2', 0.041285000000000016),\n", - " ('26', 0.04241000000000006),\n", - " ('11', 0.045487000000000055),\n", - " ('14', 0.04644000000000004),\n", - " ('15', 0.046539),\n", - " ('3', 0.049754000000000076),\n", - " ('1', 0.05246099999999998),\n", - " ('13', 0.053311),\n", - " ('16', 0.053598000000000035),\n", - " ('10', 0.05382600000000004),\n", - " ('12', 0.05742900000000006),\n", - " ('17', 0.05837900000000007),\n", - " ('25', 0.06340000000000001),\n", - " ('18', 0.07535800000000004),\n", - " ('0', 0.37953000000000003)]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "layer_importanfce" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create hybrid with any number of SSM layers" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", - "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", - "device = \"cuda\"\n", - "n_hybrid = 1\n", - "\n", - "index_swaped = []\n", - "hybrid_block_layout = [\"t\"] * config.num_hidden_layers\n", - "for i in range(n_hybrid):\n", - " hybrid_block_layout[int(layer_importanfce[i][0])] = \"m2d\"\n", - " index_swaped.append(int(layer_importanfce[i][0]))\n", - "\n", - "hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),\n", - " hybrid_block_layout=hybrid_block_layout,\n", - " ssm_cfg={\n", - " \"d_state\": 64,\n", - " \"n_v_heads\": 24,\n", - " \"n_qk_heads\": 24,\n", - " \"expand\": 1,\n", - " \"chunk_size\": 128,\n", - " \"activation\": \"identity\",\n", - " \"bias\": False,\n", - " \"d_inner\": 24 * 128, # num_heads * head_dim\n", - " })" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['t',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 'm2d',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't']" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "hybrdif_apriel_config.hybrid_block_layout" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMHybridForCausalLM(\n", - " (model): AprielSSMHybridModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-21): 22 x AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (22): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (23-27): 5 x AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (rotary_emb): AprielRotaryEmbedding()\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "hybrid_apriel_model = AprielSSMHybridForCausalLM(hybrdif_apriel_config)\n", - "hybrid_apriel_model.to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "A new version of the following files was downloaded from https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct:\n", - "- modeling_apriel.py\n", - ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n", - "Fetching 2 files: 100%|██████████| 2/2 [00:00<00:00, 23.31it/s]\n", - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 4.11it/s]\n" - ] - } - ], - "source": [ - "\n", - "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", - "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", - "apriel_state_dict = apriel_model.state_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "missing, unexpected = hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Missing keys: ['model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight']\n", - "Unexpected keys: ['model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight']\n" - ] - } - ], - "source": [ - "# unexpected will contain keys from the SSM layers we added\n", - "print(\"Missing keys:\", missing)\n", - "# unexpected will contain keys from the transformer layers we replaced\n", - "print(\"Unexpected keys:\", unexpected)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "# save the hybrid model\n", - "output_path = \"/mnt/checkpoints/ssm/iterative_hybrids_5b\"\n", - "assert len(index_swaped) == 1\n", - "layer_swaped = index_swaped[0]\n", - "hybrid_apriel_model.save_pretrained(\n", - " f\"{output_path}/apriel_ssm_instruct5b_hybrid_{layer_swaped+1}ssm_leastimportant_32h_init_rand\"\n", - " )\n", - "print(f\"Hybrid model saved to {output_path}/apriel_ssm_instruct5b_hybrid_{layer_swaped+1}ssm_leastimportant_32h_init_rand\")\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "fast_llm", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From f56964ec99ec8fd7aabd9cb3da7aa2deec10c86a Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 11 Jun 2025 20:52:10 +0000 Subject: [PATCH 121/122] test imports --- tests/test_mtp.py | 2 +- tests/test_ssms.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_mtp.py b/tests/test_mtp.py index 9f1939f14..71c55e0fc 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -21,7 +21,7 @@ from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.models.ssm.model import HybridSSMBaseModel -except ImportError: +except Exception: MambaLayer, HybridSSMBaseModel, DiscreteMamba2 = ( None, None, diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 5a3cedf1a..f3eb92617 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -24,7 +24,7 @@ from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel -except ImportError: +except Exception: MambaLayer, LlambaBlock, HybridSSMBaseModel, DiscreteMamba2 = ( None, None, From 9c50f9cca41cb2b84844923bc2cab0421e9c242f Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 11 Jun 2025 21:10:20 +0000 Subject: [PATCH 122/122] nvm --- fast_llm/layers/ssm/mamba_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 04dfdbc7b..7d0ee48a4 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -106,7 +106,7 @@ def __init__( td_x_proj, weight_init_method=kaiming_init_(td_inner.size), bias=False, - layer_lr_scale=mamba_layer_lr_scale, + lr_scale=mamba_layer_lr_scale, **factory_kwargs, ) self.x_proj.weight.auto_grad_accumulation = True