diff --git a/.dockerignore b/.dockerignore index 0ed5480a2..500fbe11c 100644 --- a/.dockerignore +++ b/.dockerignore @@ -6,6 +6,7 @@ !setup.cfg !Megatron-LM !fast_llm +!fast_llm_external_models !examples !tools !tests diff --git a/Dockerfile b/Dockerfile index 71f59fffe..526026fa4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,6 +33,7 @@ RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://gith RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ +COPY --chmod=777 ./fast_llm_external_models/__init__.py fast_llm_external_models/ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ @@ -44,4 +45,5 @@ COPY --chmod=777 ./Megatron-LM Megatron-LM COPY --chmod=777 ./examples examples COPY --chmod=777 ./tests tests COPY --chmod=777 ./tools tools +COPY --chmod=777 ./fast_llm_external_models fast_llm_external_models COPY --chmod=777 --exclude=./fast_llm/csrc/ ./fast_llm/ fast_llm/ diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 35a324db0..6f42d8b6a 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -232,7 +232,7 @@ Continuing our `AwesomeModel` handler example, we define: def _create_weight_converters(self) -> list[WeightConverter]: converters = [] # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = len(self._model.config.base_model.decoder) # A simple renaming example, for the word embeddings. converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) diff --git a/docs/recipes/generate.md b/docs/recipes/generate.md index e6bda8031..655fa29c0 100644 --- a/docs/recipes/generate.md +++ b/docs/recipes/generate.md @@ -21,12 +21,12 @@ Below is a step-by-step example of how to generate text using a Fast-LLM model c import huggingface_hub from transformers import AutoTokenizer from fast_llm.engine.checkpoint.config import CheckpointLoadConfig -from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat +from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM # Specify model and configuration model = "HuggingFaceTB/SmolLM2-135M-Instruct" -checkpoint_format = LlamaGPTHuggingfaceCheckpointFormat +checkpoint_format = LlamaCheckpointFormat max_new_tokens = 50 # Download model checkpoint from the Hugging Face Hub to a local directory diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 924bfba51..4b7fdd968 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -27,32 +27,33 @@ optimizer: beta_2: 0.95 model: base_model: - transformer: - mixer: - type: attention - rotary: - type: default - theta: 10000 - heads: 32 - head_groups: 8 - head_size: 128 - add_linear_biases: false - window_size: 4096 - dropout: 0.0 - mlp: - intermediate_size: 14336 - add_linear_biases: false - gated: true - activation: silu - normalization: - type: rms_norm - epsilon: 1.0e-05 - num_layers: 32 - hidden_size: 4096 - dropout: 0.0 embeddings_layer: + hidden_size: 4096 vocab_size: 32000 dropout: 0.0 + decoder: + block: + mixer: + type: attention + rotary: + type: default + theta: 10000 + heads: 32 + head_groups: 8 + head_size: 128 + add_linear_biases: false + window_size: 4096 + dropout: 0.0 + mlp: + intermediate_size: 14336 + add_linear_biases: false + gated: true + activation: silu + normalization: + type: rms_norm + epsilon: 1.0e-05 + dropout: 0.0 + num_blocks: 32 output_layer: tied_weight: false normalization: diff --git a/fast_llm/config.py b/fast_llm/config.py index 3352f3570..4d3858fd7 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -783,11 +783,12 @@ def _from_dict( try: actual_cls = cls.get_subclass(default.get("type")) - if actual_cls is not None and actual_cls is not cls: - return actual_cls._from_dict(default, strict=strict, flat=flat) except KeyError: - # Postpone error to validation. - pass + # Try to postpone error to validation. + actual_cls = cls + + if actual_cls is not None and actual_cls is not cls: + return actual_cls._from_dict(default, strict=strict, flat=flat) # Do not validate yet in case the root class sets cross-dependencies in validation. with NoAutoValidate(): diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 9de5ac2cc..0a3f8d1ce 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -1,5 +1,4 @@ import abc -import dataclasses import typing import torch @@ -78,17 +77,6 @@ def setup(self, distributed: Distributed) -> None: layer.setup(distributed) -@dataclasses.dataclass() -class LossDef: - # A name for the loss - name: str - formatted_name: str - # The number of times this loss is evaluated by the model for each micro-batch. Used as a denominator for averaging. - # TODO: Allow variable count? Would need a reduction across PP devices. - count: int = 1 - dtype: torch.dtype = torch.float32 - - class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential): def __init__( @@ -135,11 +123,6 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: # The name (dict key) is used to insert the weight in the kwargs of the forward pass. return {} - @property - @abc.abstractmethod - def loss_defs(self) -> list[LossDef]: - pass - def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None: assert name not in self._reference_models assert not self._is_setup diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 2b55d782e..78fafea34 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -3,6 +3,7 @@ import typing from fast_llm.config import MISSING, Config, Field, FieldHint, FieldVerboseLevel, config_class +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import compare_nested, log if typing.TYPE_CHECKING: @@ -75,3 +76,14 @@ class ResourceUsageConfig: forward: int = 1 # Number of backward passes. Typically 1 for training, 0 for inference. backward: int = 1 + + +@dataclasses.dataclass() +class LossDef: + # A name for the loss + name: str + formatted_name: str + # The number of times this loss is evaluated by the model for each micro-batch. Used as a denominator for averaging. + # TODO: Allow variable count? Would need a reduction across PP devices. + count: int = 1 + dtype: DataType = DataType.float32 diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 72db80f6a..886c706c1 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -1,5 +1,4 @@ import abc -import dataclasses import logging import pathlib import typing @@ -7,7 +6,7 @@ import torch from fast_llm import __version__ -from fast_llm.config import MISSING, get_nested_dict_value, set_nested_dict_value +from fast_llm.config import Config from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.state_dict import StateDictCheckpointHandler @@ -19,124 +18,12 @@ logger = logging.getLogger(__name__) -@dataclasses.dataclass(kw_only=True) -class ParamConverter(abc.ABC): - fast_llm_names: tuple[tuple[str, ...], ...] = () # Array of fast-llm names, in nested (tuple) format. - export_names: tuple[tuple[str, ...], ...] = () # Array of export names, in nested (tuple) format. - - @abc.abstractmethod - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - pass - - @abc.abstractmethod - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - pass - - -@dataclasses.dataclass(kw_only=True) -class RenameParamConverter(ParamConverter): - ignore_missing: bool = False - default_value: typing.Any = None - - def __post_init__(self) -> None: - Assert.eq(len(self.fast_llm_names), 1) - Assert.eq(len(self.export_names), 1) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return fast_llm_values - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - if self.ignore_missing: - if export_values[0] == MISSING: - logger.warning( - "The configuration parameter `%s=%s` is ignored during conversion as it is not present in the checkpoint.", - self.export_names[0], - export_values[0], - ) - return (self.default_value,) - return export_values - - -# def __repr__(self): -# return f"RenameParamConverter({'.'.join(self.fast_llm_names[0])} <--> {'.'.join(self.export_names[0])})" - - -@dataclasses.dataclass(kw_only=True) -class ConstantImportParamConverter(ParamConverter): - fast_llm_value: typing.Any = MISSING - - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 1) - Assert.eq(len(self.export_names), 0) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - Assert.eq(fast_llm_values[0], self.fast_llm_value) - return () - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (self.fast_llm_value,) - - -@dataclasses.dataclass(kw_only=True) -class ConstantExportParamConverter(ParamConverter): - export_value: typing.Any = MISSING - - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 0) - Assert.eq(len(self.export_names), 1) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (self.export_value,) - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - Assert.eq(export_values[0], self.export_value) - return () - - -@dataclasses.dataclass(kw_only=True) -class IgnoreImportParamConverter(ParamConverter): - ignore_export_value: typing.Any = MISSING - - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 0) - Assert.eq(len(self.export_names), 1) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (MISSING,) - - def import_params(self, export_values): - if export_values[0] not in (self.ignore_export_value, MISSING): - logger.warning( - "The configuration parameter `%s=%s` is ignored during conversion." - " If you intend to use it in Fast-LLM, make sure to set it explicitly in the model configuration.", - self.export_names[0], - export_values[0], - ) - return () - - -@dataclasses.dataclass(kw_only=True) -class MappedConfigParamConverter(ParamConverter): - fast_llm_value: typing.Callable[[typing.Any], typing.Any] = lambda x: x - export_value: typing.Callable[[typing.Any], typing.Any] = lambda x: x - - def __post_init__(self) -> None: - Assert.eq(len(self.fast_llm_names), 1) - Assert.eq(len(self.export_names), 1) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (self.export_value(fast_llm_values[0]),) - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (self.fast_llm_value(export_values[0]),) - - class WeightConverter: def __init__( self, fast_llm_name: str | tuple[str, ...], export_name: str | tuple[str, ...], - config: BaseModelConfig | None = None, + config: Config | None = None, ): self.fast_llm_name: tuple[str, ...] = (fast_llm_name,) if isinstance(fast_llm_name, str) else fast_llm_name self.export_name: tuple[str, ...] = (export_name,) if isinstance(export_name, str) else export_name @@ -216,7 +103,6 @@ def import_weight( class ExternalStateDictCheckpointHandler(StateDictCheckpointHandler): _model_class: typing.ClassVar[FastLLMModelConfig] - _config_converters: list[ParamConverter] def __init__(self, model: "FastLLMModel"): super().__init__(model) @@ -239,20 +125,14 @@ def __init__(self, model: "FastLLMModel"): @classmethod def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: - imported_model_config = cls._import_config(cls._load_config(config.path)) return CheckpointMetadata( fast_llm_version=__version__, model=cls._model_class, format=config.format, - config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + config=cls._import_config(cls._load_config(config.path)), shards=["weights"], ) - @classmethod - @abc.abstractmethod - def _create_config_converters(cls) -> list[ParamConverter]: - pass - @abc.abstractmethod def _create_weight_converters(self) -> list[WeightConverter]: pass @@ -263,51 +143,15 @@ def _load_config(cls, directory: pathlib.Path | str) -> dict: pass @classmethod - def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: - # TODO v0.3: not used in this class - exported_config = {} - for converter in cls._get_config_converters(): - try: - values = converter.export_params( - tuple( - cls._get_fast_llm_attribute(config, fast_llm_name) - for fast_llm_name in converter.fast_llm_names - ) - ) - for export_name, value in zip(converter.export_names, values, strict=True): - if value is not MISSING: - set_nested_dict_value(exported_config, export_name, value) - except Exception as e: - raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) - - return exported_config # Noqa + @abc.abstractmethod + def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: + # TODO: not used in this class + pass @classmethod - def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # noqa - kwargs = {} - for converter in cls._get_config_converters(): - try: - values = () - for export_name in converter.export_names: - try: - value = get_nested_dict_value(config, export_name) - except KeyError: - value = MISSING - values = values + (value,) - values = converter.import_params(values) - for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): - if value is MISSING: - # Missing values need to be handled in dedicated converters, - # because implicit / default values may not match. - # TODO: Different behavior from other uses of MISSING. Use different tag? - raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") - if fast_llm_name in kwargs: - raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") - kwargs[fast_llm_name] = value - except Exception as e: - raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) - - return cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + @abc.abstractmethod + def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: + pass def _convert_state_dict( self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool @@ -343,12 +187,6 @@ def _convert_state_dict( return out_state_dict - @classmethod - def _get_config_converters(cls) -> list[ParamConverter]: - if not hasattr(cls, "_config_converters"): - cls._config_converters = cls._create_config_converters() - return cls._config_converters - @staticmethod def _get_fast_llm_attribute(config: BaseModelConfig, name: str | tuple[str, ...]) -> typing.Any: if isinstance(name, str): @@ -374,6 +212,6 @@ def get_handler_class(cls, format: str) -> type[ExternalStateDictCheckpointHandl # TODO: load_metadata??? @classmethod - def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: + def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: # TODO: ??? return cls.handler_map[config["model_type"]]._import_config(config) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 16b3e005f..e5d14711d 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -6,21 +6,47 @@ import safetensors import torch -from transformers.configuration_utils import PretrainedConfig +from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointSaveMetadataConfig -from fast_llm.engine.checkpoint.external import ( - ConstantExportParamConverter, - ExternalStateDictCheckpointHandler, - ParamConverter, - logger, -) -from fast_llm.engine.multi_stage.config import CheckpointMetadata +from fast_llm.engine.checkpoint.external import ExternalStateDictCheckpointHandler, WeightConverter, logger +from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert +from fast_llm.utils import Assert, safe_merge_dicts + +if typing.TYPE_CHECKING: + import transformers + + +class HuggingFaceBaseModelConverter: + @classmethod + @abc.abstractmethod + def import_config(cls, config: dict) -> dict: + pass + + @classmethod + @abc.abstractmethod + def export_config(cls, config: BaseModelConfig) -> dict: + pass + + @classmethod + @abc.abstractmethod + def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]: + pass class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC): + architecture: typing.ClassVar[str] + base_model_converter_class: typing.ClassVar[type[HuggingFaceBaseModelConverter]] + + @classmethod + @abc.abstractmethod + def get_transformers_configuration_class(cls) -> type["transformers.PretrainedConfig"]: + pass + + @classmethod + def get_model_files(cls) -> tuple[str | None, str | None, str | None]: + return None, None, None @classmethod def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: @@ -35,7 +61,7 @@ def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadat ) def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata) -> dict: - huggingface_config = self._export_config(self._model.config.base_model) + huggingface_config = self._export_config(self._model.config) self._save_config(config.path, huggingface_config) return { "fast_llm_metadata": metadata.to_dict(), @@ -49,6 +75,20 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) super().load(config) + def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: + super().save(config, metadata) + # Copy the modeling files to the output directory + modeling_file, configuration_file, generation_utils_file = self.get_model_files() + if configuration_file is not None: + shutil.copy(configuration_file, config.path) + if modeling_file is not None: + shutil.copy(modeling_file, config.path) + if generation_utils_file is not None: + shutil.copy(generation_utils_file, config.path) + gen_config = pathlib.Path(generation_utils_file).parent / "generation_config.json" + if gen_config.exists(): + shutil.copy(gen_config, config.path) + @classmethod def get_huggingface_model_type(self) -> str: # We assume the two names match, but derived classes can make it different. @@ -59,28 +99,37 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str: Assert.eq(shard_name, "weights") return parameter_name - @classmethod - @abc.abstractmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return [ - ConstantExportParamConverter( - export_names=(("model_type",),), export_value=cls.get_huggingface_model_type() - ) - ] - + # Use custom config instead of relying on the transformers library @classmethod def _load_config(cls, directory: pathlib.Path | str) -> dict: - import transformers - - config = transformers.AutoConfig.from_pretrained(directory).to_dict() + config = cls.get_transformers_configuration_class().from_pretrained(directory).to_dict() Assert.eq(config["model_type"], cls.get_huggingface_model_type()) return config @classmethod def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - import transformers + cls.get_transformers_configuration_class().from_dict(config).save_pretrained(directory) + + @classmethod + def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: + return safe_merge_dicts( + cls.base_model_converter_class.export_config(config.base_model), + { + "model_type": cls.get_huggingface_model_type(), + "architecture": cls.architecture, + }, + ) + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + Assert.eq(config["architecture"], cls.architecture) + return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)}) - transformers.CONFIG_MAPPING[config["model_type"]].from_dict(config).save_pretrained(directory) + def _create_weight_converters( + self, + ) -> list[WeightConverter]: + return self.base_model_converter_class.get_converters(self._model.config.base_model) def _load_weights( self, config: CheckpointLoadConfig, device @@ -123,39 +172,3 @@ def _load_weights( yield from torch.load(path) else: raise NotImplementedError(f"Unknown file format for {path}") - - -class CustomModelingExportMixin: - """ - Mixin class for HuggingfaceStateDictCheckpointHandler to handle custom modeling files. - """ - - modeling_file: typing.ClassVar[str] - configuration_file: typing.ClassVar[str] - configuration_cls: typing.ClassVar[type[PretrainedConfig]] - generation_utils_file: str | None = None - - # Use custom config instead of relying on the transformers library - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - config = cls.configuration_cls.from_pretrained(directory).to_dict() - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - cls.configuration_cls.from_dict(config).save_pretrained(directory) - - def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: - super().save(config, metadata) - self._copy_modeling_files(config) - - def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None: - # Copy the modeling files to the output directory - shutil.copy(self.modeling_file, config.path) - shutil.copy(self.configuration_file, config.path) - if self.generation_utils_file: - shutil.copy(self.generation_utils_file, config.path) - gen_config = pathlib.Path(self.generation_utils_file).parent / "generation_config.json" - if gen_config.exists(): - shutil.copy(gen_config, config.path) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 33e4d654f..d5202a90f 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -116,7 +116,7 @@ def setup( phase=PhaseType.validation, ) - self._loss_defs = self._multi_stage.base_model.loss_defs + self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions() self._evaluation_iterator = None self._is_setup = True diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index b38056adb..e48fdb88b 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -468,9 +468,7 @@ def get_state_tensor_iterator( ) -> typing.Generator[tuple[str, str, torch.Tensor], None, None]: for shard_name in shard_names: shard_split = self._shards[shard_name].split(self._stage_weight_shard_sizes, 0) - for shard_index, (stage, shard) in enumerate( - zip(self._stages_on_device.values(), shard_split, strict=True) - ): + for shard_index, (stage, shard) in enumerate(zip(self._stages_owned.values(), shard_split, strict=True)): for name, tensor in stage._export_shard( shard.split(self._fsdp_weight_shard_sizes[shard_index]), data_type=data_type ): # noqa diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 21ecbe476..dbdd035a4 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -93,7 +93,10 @@ def __init__( self._stages: list[Stage] = self._multi_stage.stages self._tied_parameters = self._multi_stage.tied_parameters self._num_stages = len(self._stages) - self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs} + self._loss_definitions = { + loss_definition.name: loss_definition + for loss_definition in self._multi_stage.base_model.config.get_loss_definitions() + } def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None: assert not self._is_setup @@ -148,7 +151,7 @@ def run_step( context = BatchContext( iteration=iteration, schedule=schedule, - losses={loss_def: [] for loss_def in self._loss_defs}, + losses={loss_def: [] for loss_def in self._loss_definitions}, metrics=metrics, ) context.data_iterator = self._preprocess_data(context, data_iterator, preprocessed) @@ -280,11 +283,13 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: for name, losses in context.losses.items(): if losses or self._distributed.pipeline_group: if losses: - reduced_loss = torch.stack(losses).sum() / num_inputs / self._loss_defs[name].count + reduced_loss = torch.stack(losses).sum() / num_inputs / self._loss_definitions[name].count if self._distributed.data_group: all_reduce(reduced_loss, group=self._distributed.data_group) else: - reduced_loss = torch.zeros([1], dtype=self._loss_defs[name].dtype, device=self._distributed.device) + reduced_loss = torch.zeros( + [1], dtype=self._loss_definitions[name].dtype.torch, device=self._distributed.device + ) if self._distributed.pipeline_group: all_reduce(reduced_loss, group=self._distributed.pipeline_group) else: diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 305f3ebf3..7db9b1fc3 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -149,7 +149,7 @@ def __init__(self, config: TrainerConfig): multi_stage=self._multi_stage, distributed_config=self._config.model.distributed, ) - self._loss_defs = self._multi_stage.base_model.loss_defs + self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions() if not self._is_evaluation_only: steps_per_split = { diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index bbd70ede4..9a940f4cb 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -10,9 +10,9 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs -from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -49,7 +49,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): +class Attention[ConfigType: AttentionConfig](BlockWithBias[ConfigType]): """ A self-attention layer. """ diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 868d6ba77..214bb7729 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -9,8 +9,9 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig -from fast_llm.layers.block.config import BlockKwargs, MixerConfig +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig +from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 5187ebfdc..773cce87e 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -1,4 +1,3 @@ -import abc import functools import logging import typing @@ -6,14 +5,12 @@ import torch from fast_llm.config import Config, Configurable -from fast_llm.core.distributed import set_generator from fast_llm.engine.base_model.base_model import Layer, Module from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.logging import get_model_debug_level, log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -22,6 +19,10 @@ class DebugLayer: + """ + A debugging utility for blocks. + """ + # TODO: Move elsewhere? def __init__(self, module: torch.nn.Module): self._module = module @@ -92,9 +93,9 @@ def __call__[ ) -class BlockLayerBase[ConfigType: Config](Configurable[ConfigType], Module): +class BaseBlock[ConfigType: Config](Configurable[ConfigType], Module): """ - Base class for blocks, mixers, MLPs, etc. + Base class for blocks and block-like layers (mlp, mixers, etc.). """ def __init__( @@ -118,25 +119,9 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c raise NotImplementedError() -class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]): +class Block[ConfigType: Config](BaseBlock[ConfigType], Layer): """ - Base class for mixer and MLP modules. - """ - - @abc.abstractmethod - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - pass - - -class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer): - """ - A transformer-like decoder base block with abstract mixer. + Base class for actual blocks, i.e., base blocks that are also `Layers`. """ def __init__( @@ -149,103 +134,5 @@ def __init__( peft: PeftConfig | None, return_input: bool = False, ): - super().__init__( - config, - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, - ) - # For multi-token prediction, return a stack of shared_hidden and transformer_output. - self._return_input: bool = return_input - # Note, layer_lr_scale does not impact the norms - # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) - self.norm_2 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) - - # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. - self.mixer = self._config.mixer.get_layer( - self._distributed_config, - self._hidden_dim, - self._lr_scale, - peft=peft, - ) - - self.mlp = self._config.mlp.get_layer( - self._distributed_config, - self._hidden_dim, - self._lr_scale, - peft=peft, - ) - - def setup(self, distributed: Distributed) -> None: - super().setup(distributed) - self.mixer.setup(distributed) - self.mlp.setup(distributed) - - @torch.compile - def _bias_dropout_add( - self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor - ) -> torch.Tensor: - if bias is not None: - input_ = input_ + bias - return residual + torch.dropout(input_, self._config.dropout, self.training) - - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> torch.Tensor: - if isinstance(input_, TensorMeta): - dims = kwargs[BlockKwargs.hidden_dims] - if self._return_input: - dims = (TensorDim("stacked_input_output", 2),) + dims - return TensorMeta.from_dims(dims, tensor_name=f"{self.module_name} output", dtype=input_.dtype) - generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator - if self._debug.enabled: - self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) - fw_input = input_ - hidden_states = self.norm_1(input_) - if self._debug.enabled: - self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states, bias = self.mixer(hidden_states, kwargs) - if self._debug.enabled: - self._debug( - hidden_states if bias is None else hidden_states + bias, - "mixer output", - kwargs[BlockKwargs.hidden_dims], - kwargs, - ) - with set_generator(generator): - input_ = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug.enabled: - self._debug(input_, "mixer residual", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states = self.norm_2(input_) - if self._debug.enabled: - self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) - if self._debug.enabled: - self._debug( - hidden_states if bias is None else hidden_states + bias, - "MLP output", - kwargs[BlockKwargs.hidden_dims], - kwargs, - ) - with set_generator(generator): - hidden_states = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug.enabled: - self._debug(None, "MLP residual", kwargs[BlockKwargs.hidden_dims], kwargs) - if self._return_input: - hidden_states = torch.stack((fw_input, hidden_states), dim=0) - return hidden_states - - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - # TODO: Add marginal compute? (normalization, bias_dropout_add) - return sum( - ( - self.mixer.get_compute_usage(input_, kwargs, config), - self.mlp.get_compute_usage(input_, kwargs, config), - ) - ) + super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) + self._return_input = return_input diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index fd42bccf9..7df2705fa 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,19 +1,19 @@ +import abc +import collections +import functools import typing +import warnings from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import BaseModelConfig, Preprocessor +from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, Preprocessor from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.block.block import BlockLayer - - -# TODO: Generalize these beyond language models? (Ex. vision) + from fast_llm.layers.block.block import Block class BlockDimNames: @@ -41,13 +41,12 @@ class BlockKwargs: @config_class() -class BlockLayerConfig(BaseModelConfig): +class BaseBlockConfig(BaseModelConfig): """ - A common class for mixers and mlps, which have the same interface. + Base configuration class for blocks and block-like layers (mlp, mixers, etc.). """ _abstract = True - block: "BlockConfig" = Field(init=False) lr_scale: float | None = Field( default=None, @@ -56,34 +55,57 @@ class BlockLayerConfig(BaseModelConfig): hint=FieldHint.feature, ) + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + return [] + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return [] + + +@config_class(registry=True) +class BlockConfig(BaseBlockConfig): + """ + Base configuration class for actual blocks, i.e., base blocks that are also `Layers`. + """ + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is BlockConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.decoder.config import DecoderBlockConfig + + # Default subclass. + return DecoderBlockConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + @property - def layer_class(self) -> "type[BlockLayer]": + def layer_class(self) -> "type[Block]": raise NotImplementedError() - def get_layer( + def get_block( self, distributed_config: DistributedConfig, hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - ) -> "BlockLayer": + return_input: bool = False, + ) -> "Block": return self.layer_class( self, distributed_config, hidden_dim=hidden_dim, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, + return_input=return_input, ) - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # TODO: Move to actual layers? - return [] - @config_class(registry=True) -class MLPBaseConfig(BlockLayerConfig): - _abstract = True - +class BlockSequenceConfig(BaseModelConfig): @classmethod def _from_dict( cls, @@ -91,91 +113,105 @@ def _from_dict( strict: bool = True, flat: bool = False, ) -> typing.Self: - if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: - from fast_llm.layers.block.mlp.config import MLPConfig - + if cls is BlockSequenceConfig and cls.get_subclass(default.get("type")) is None: # Default subclass. - return MLPConfig._from_dict(default, strict, flat) + return FixedBlockSequenceConfig._from_dict(default, strict, flat) return super()._from_dict(default, strict=strict, flat=flat) + @abc.abstractmethod + def __len__(self) -> int: + pass -@config_class(registry=True) -class MixerConfig(BlockLayerConfig): - """ - Base config class for mixers. - """ + @abc.abstractmethod + def __getitem__(self, index: int) -> BlockConfig: + pass - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: - from fast_llm.layers.attention.config import AttentionConfig + @abc.abstractmethod + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + pass - # Default subclass. - return AttentionConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return [] -@config_class() -class BlockConfig(BaseModelConfig): +@config_class(dynamic_type={BlockSequenceConfig: "fixed"}) +class FixedBlockSequenceConfig(BlockSequenceConfig): _abstract = False - mixer: MixerConfig = Field() - mlp: MLPBaseConfig = Field() - # TODO: Review names - normalization: NormalizationConfig = Field( - desc="Configuration for the block normalization layers.", + block: BlockConfig = Field( + desc="Common configuration for all the blocks.", hint=FieldHint.architecture, ) - lr_scale: float | None = Field( - default=None, - desc="Scaling factor for the layer learning rate." - " Combines multiplicatively with the scale set by the parent and child layers, if applicable.", - hint=FieldHint.feature, - ) - # TODO: Review names - dropout: float = Field( - default=0.0, - desc="Dropout applied to the residual connections.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - # TODO: Move these, not specific to a single block. - num_layers: int = Field( + num_blocks: int = Field( default=12, desc="Number of blocks in the model.", hint=FieldHint.architecture, valid=check_field(Assert.geq, 0), ) - hidden_size: int = Field( - default=1024, - desc="Size of the transformer's main hidden dimension, e.g., for its input and output layers.", + + def __len__(self) -> int: + return self.num_blocks + + def __getitem__(self, index: int) -> BlockConfig: + return self.block + + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + # TODO: Prevent name conflicts in preprocessed kwargs. + return self.block.get_preprocessors(distributed_config) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return self.block.get_loss_definitions(count=count * self.num_blocks) + + +@config_class(dynamic_type={BlockSequenceConfig: "pattern"}) +class PatternBlockSequenceConfig(BlockSequenceConfig): + _abstract = False + blocks: dict[str, BlockConfig] = Field() + pattern: list[str] = Field( + default=None, + desc="The name of each block (key in `blocks`) in the repeated pattern.", + hint=FieldHint.architecture, + ) + num_blocks: int = Field( + default=12, + desc="Number of blocks in the model.", hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), + valid=check_field(Assert.geq, 0), ) - def get_layer( - self, - distributed_config: DistributedConfig, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None = None, - return_input: bool = False, - ): - from fast_llm.layers.block.block import Block + def _validate(self): + if not self.blocks: + raise ValueError("No block configuration provided") + if not self.pattern: + raise ValueError("No block pattern provided") + used_blocks = set(self.pattern) + available_blocks = set(self.blocks) + if missing := used_blocks - available_blocks: + raise ValueError(f"The following blocks are present in the pattern but undefined: {missing}") + if extra := available_blocks - used_blocks: + raise warnings.warn(f"The following blocks are defined but unused: {extra}") - return Block( - self, - distributed_config, - hidden_dim=hidden_dim, - lr_scale=combine_lr_scales(lr_scale, self.lr_scale), - peft=peft, - return_input=return_input, - ) + super()._validate() + + def __len__(self) -> int: + return self.num_blocks + + def __getitem__(self, index: int) -> BlockConfig: + return self.blocks[self.expanded_pattern[index]] + + @functools.cached_property + def expanded_pattern(self) -> list[str]: + return (self.pattern * (self.num_blocks // len(self.pattern) + 1))[: self.num_blocks] def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # TODO: Move to actual layers? - return self.mixer.get_preprocessors(distributed_config) + self.mlp.get_preprocessors(distributed_config) + # TODO: Prevent name conflicts in preprocessed kwargs. + return sum((block.get_preprocessors(distributed_config) for block in self.blocks.values()), []) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + # TODO: Prevent name conflicts. + return sum( + ( + self.blocks[name].get_loss_definitions(count=count * count_) + for name, count_ in collections.Counter(self.expanded_pattern).items() + ), + [], + ) diff --git a/fast_llm/layers/block/mlp/__init__.py b/fast_llm/layers/block/sequence.py similarity index 100% rename from fast_llm/layers/block/mlp/__init__.py rename to fast_llm/layers/block/sequence.py diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index 2ed97ae66..e7c6d9e92 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -19,16 +19,16 @@ class LinearBaseConfig(Config): Configuration for a linear-like layer without bias. """ + weight: ParameterConfig = Field( + desc="Configuration for the weight.", + hint=FieldHint.architecture, + ) lr_scale: float | None = Field( default=None, desc="Scaling factor for the layer learning rate." " Combines multiplicatively with the scale set by the parent layer and individual parameters, if applicable.", hint=FieldHint.feature, ) - weight: ParameterConfig = Field( - desc="Initialization configuration for the weight.", - hint=FieldHint.feature, - ) @config_class() diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 3401e61be..33cbd9768 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -4,7 +4,7 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.parameter import combine_lr_scales +from fast_llm.engine.config_utils.parameter import ParameterConfig, combine_lr_scales from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert @@ -81,7 +81,10 @@ class LayerNormalizationBaseConfig(NormalizationConfig): Common configuration for layer norm and rms norm """ - # TODO: Rename to normalization_epsilon + weight: ParameterConfig = Field( + desc="Configuration for the weight.", + hint=FieldHint.architecture, + ) epsilon: float = Field( default=1e-5, desc="Regularizer for the division.", @@ -98,13 +101,6 @@ class LayerNormalizationBaseConfig(NormalizationConfig): desc="The implementation to use for the normalization layer.", hint=FieldHint.performance, ) - # TODO: Rename to normalization_init_range - initialization_range: float = Field( - default=0.0, - desc="Randomize the initialization with a uniform noise. Used to test for issues that may not be visible with the default initialization.", - hint=FieldHint.testing, - valid=check_field(Assert.geq, 0), - ) @property @abc.abstractmethod @@ -128,6 +124,10 @@ def _from_dict( @config_class(dynamic_type={NormalizationConfig: "layer_norm"}) class LayerNormalizationConfig(LayerNormalizationBaseConfig): + bias: ParameterConfig = Field( + desc="Configuration for the weight.", + hint=FieldHint.architecture, + ) _abstract = False @property diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 0dc7b9589..d0a5ab151 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -3,7 +3,7 @@ import torch from fast_llm.config import Configurable -from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ +from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig @@ -15,7 +15,7 @@ NormalizationImplementation, RMSNormalizationConfig, ) -from fast_llm.tensor import ParameterMeta, accumulate_gradient +from fast_llm.tensor import accumulate_gradient from fast_llm.utils import Assert try: @@ -205,23 +205,17 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | else: raise NotImplementedError(implementation) - if self.config.initialization_range: - mean = 0 if self.zero_centered else 1 - weight_init_method = init_uniform_centered_(self.config.initialization_range, mean=mean) - else: - weight_init_method = init_zeros_ if self._config.zero_centered else init_ones_ - - self.weight = ParameterMeta.from_dims( + self.weight = self._config.weight.get_parameter( (hidden_dim,), - init_method=weight_init_method, - weight_decay=False, + default_initialization=init_zeros_ if self._config.zero_centered else init_ones_, lr_scale=self._lr_scale, + peft=None, ) - self.bias = ParameterMeta.from_dims( + self.bias = self._config.bias.get_parameter( (hidden_dim,), - init_method=init_zeros_, - weight_decay=False, + default_initialization=init_zeros_, lr_scale=self._lr_scale, + peft=None, ) self._normalized_shape = self.weight.shape @@ -277,17 +271,11 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | else: raise NotImplementedError(implementation) - if self.config.initialization_range: - mean = 0 if self.zero_centered else 1 - weight_init_method = init_uniform_centered_(self.config.initialization_range, mean=mean) - else: - weight_init_method = init_zeros_ if self._config.zero_centered else init_ones_ - - self.weight = ParameterMeta.from_dims( + self.weight = self._config.weight.get_parameter( (hidden_dim,), - init_method=weight_init_method, - weight_decay=False, - lr_scale=lr_scale, + default_initialization=init_zeros_ if self._config.zero_centered else init_ones_, + lr_scale=self._lr_scale, + peft=None, ) self._normalized_shape = self.weight.shape diff --git a/fast_llm/models/custom/__init__.py b/fast_llm/layers/decoder/__init__.py similarity index 100% rename from fast_llm/models/custom/__init__.py rename to fast_llm/layers/decoder/__init__.py diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py new file mode 100644 index 000000000..ba4c370c2 --- /dev/null +++ b/fast_llm/layers/decoder/block.py @@ -0,0 +1,152 @@ +import abc +import logging +import typing + +import torch + +from fast_llm.config import Config +from fast_llm.core.distributed import set_generator +from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.block import BaseBlock, Block +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.tensor import TensorMeta + +logger = logging.getLogger(__name__) + + +class BlockWithBias[ConfigType: Config](BaseBlock[ConfigType]): + """ + Base class for mixer and MLP modules. + """ + + @abc.abstractmethod + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + pass + + +class DecoderBlock[ConfigType: DecoderBlockConfig](Block[ConfigType]): + """ + A transformer-like decoder base block with abstract mixer. + """ + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_input: bool = False, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + # For multi-token prediction, return a stack of shared_hidden and transformer_output. + self._return_input: bool = return_input + # Note, layer_lr_scale does not impact the norms + # TODO: add a separate norm_lr_scale + self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + self.norm_2 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + + # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. + self.mixer = self._config.mixer.get_layer( + self._distributed_config, + self._hidden_dim, + self._lr_scale, + peft=peft, + ) + + self.mlp = self._config.mlp.get_layer( + self._distributed_config, + self._hidden_dim, + self._lr_scale, + peft=peft, + ) + + def setup(self, distributed: Distributed) -> None: + super().setup(distributed) + self.mixer.setup(distributed) + self.mlp.setup(distributed) + + @torch.compile + def _bias_dropout_add( + self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor + ) -> torch.Tensor: + if bias is not None: + input_ = input_ + bias + return residual + torch.dropout(input_, self._config.dropout, self.training) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + dims = kwargs[BlockKwargs.hidden_dims] + if self._return_input: + dims = (TensorDim("stacked_input_output", 2),) + dims + return TensorMeta.from_dims(dims, tensor_name=f"{self.module_name} output", dtype=input_.dtype) + generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator + if self._debug.enabled: + self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) + fw_input = input_ + hidden_states = self.norm_1(input_) + if self._debug.enabled: + self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states, bias = self.mixer(hidden_states, kwargs) + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + "mixer output", + kwargs[BlockKwargs.hidden_dims], + kwargs, + ) + with set_generator(generator): + input_ = self._bias_dropout_add(hidden_states, bias, input_) + if self._debug.enabled: + self._debug(input_, "mixer residual", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states = self.norm_2(input_) + if self._debug.enabled: + self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + "MLP output", + kwargs[BlockKwargs.hidden_dims], + kwargs, + ) + with set_generator(generator): + hidden_states = self._bias_dropout_add(hidden_states, bias, input_) + if self._debug.enabled: + self._debug(None, "MLP residual", kwargs[BlockKwargs.hidden_dims], kwargs) + if self._return_input: + hidden_states = torch.stack((fw_input, hidden_states), dim=0) + return hidden_states + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # TODO: Add marginal compute? (normalization, bias_dropout_add) + return sum( + ( + self.mixer.get_compute_usage(input_, kwargs, config), + self.mlp.get_compute_usage(input_, kwargs, config), + ) + ) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py new file mode 100644 index 000000000..2d8cc71fd --- /dev/null +++ b/fast_llm/layers/decoder/config.py @@ -0,0 +1,111 @@ +import typing + +from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.base_model.config import LossDef, Preprocessor +from fast_llm.engine.config_utils.parameter import combine_lr_scales +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.config import BaseBlockConfig, BlockConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.layers.decoder.block import BlockWithBias, DecoderBlock + + +@config_class() +class BlockWithBiasConfig(BaseBlockConfig): + """ + A common interface for various blocks and block layers. + """ + + @property + def layer_class(self) -> "type[BlockWithBias]": + raise NotImplementedError() + + def get_layer( + self, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ) -> "BlockWithBias": + return self.layer_class( + self, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + ) + + +@config_class(registry=True) +class MLPBaseConfig(BlockWithBiasConfig): + _abstract = True + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.decoder.mlp.config import MLPConfig + + # Default subclass. + return MLPConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(registry=True) +class MixerConfig(BlockWithBiasConfig): + """ + Base config class for mixers. + """ + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.attention.config import AttentionConfig + + # Default subclass. + return AttentionConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(dynamic_type={BlockConfig: "decoder"}) +class DecoderBlockConfig(BlockConfig): + _abstract = False + mixer: MixerConfig = Field() + mlp: MLPBaseConfig = Field() + # TODO: Review names + normalization: NormalizationConfig = Field( + desc="Configuration for the block normalization layers.", + hint=FieldHint.architecture, + ) + # TODO: Review names + dropout: float = Field( + default=0.0, + desc="Dropout applied to the residual connections.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + + @property + def layer_class(self) -> "type[DecoderBlock]": + from fast_llm.layers.decoder.block import DecoderBlock + + return DecoderBlock + + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: + return self.mixer.get_preprocessors(distributed_config) + self.mlp.get_preprocessors(distributed_config) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count) diff --git a/fast_llm/layers/decoder/mlp/__init__.py b/fast_llm/layers/decoder/mlp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py similarity index 81% rename from fast_llm/layers/block/mlp/config.py rename to fast_llm/layers/decoder/mlp/config.py index 3d8a9c2bf..100f53740 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -3,14 +3,15 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.base_model.config import LossDef from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.layers.block.config import MLPBaseConfig from fast_llm.layers.common.linear.config import AffineLinearConfig, LinearConfig +from fast_llm.layers.decoder.config import MLPBaseConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP - from fast_llm.layers.block.mlp.mlp import MLP + from fast_llm.layers.decoder.mlp.mixture_of_experts import MixtureOfExpertMLP + from fast_llm.layers.decoder.mlp.mlp import MLP class MLPLossNames: @@ -74,7 +75,7 @@ def _validate(self) -> None: @property def layer_class(self) -> "type[MLP]": - from fast_llm.layers.block.mlp.mlp import MLP + from fast_llm.layers.decoder.mlp.mlp import MLP return MLP @@ -87,10 +88,10 @@ class MoEMLPConfig(MLPConfig): hint=FieldHint.feature, ) experts: int = Field( - default=1, + default=2, desc="Number of MLP experts in a Mixture of Expert (MoE) model", hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), + valid=check_field(Assert.gt, 1), ) shared_experts: int = Field( default=0, @@ -139,7 +140,7 @@ class MoEMLPConfig(MLPConfig): @property def layer_class(self) -> "type[MixtureOfExpertMLP]": - from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP + from fast_llm.layers.decoder.mlp.mixture_of_experts import MixtureOfExpertMLP return MixtureOfExpertMLP @@ -151,3 +152,23 @@ def _validate(self) -> None: super()._validate() Assert.leq(self.shared_experts, self.experts) Assert.leq(self.shared_experts + self.experts_per_token, self.experts) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + loss_definitions = [] + if self.routing == RoutingType.topk: + loss_definitions.append( + LossDef( + name=MLPLossNames.load_balancing_loss, + formatted_name="load balancing loss", + count=1, + ) + ) + if self.z_loss_coefficient: + loss_definitions.append( + LossDef( + name=MLPLossNames.router_z_loss, + formatted_name="router z loss", + count=1, + ) + ) + return loss_definitions diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py similarity index 98% rename from fast_llm/layers/block/mlp/mixture_of_experts.py rename to fast_llm/layers/decoder/mlp/mixture_of_experts.py index 9478dc51c..089fa2dc7 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -13,10 +13,10 @@ from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.block.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType -from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType +from fast_llm.layers.decoder.mlp.mlp import MLPBase from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py similarity index 96% rename from fast_llm/layers/block/mlp/mlp.py rename to fast_llm/layers/decoder/mlp/mlp.py index c88f766b0..fe4879e73 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -10,13 +10,13 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.tensor import TensorMeta -class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): +class MLPBase[ConfigType: MLPConfig](BlockWithBias[ConfigType]): _config: ConfigType def __init__( diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index c15515fb5..849e09aa9 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,12 +1,12 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import BaseModelConfig, Preprocessor +from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, Preprocessor from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert @@ -42,7 +42,7 @@ class LanguageModelKwargs(BlockKwargs): @config_class() -class LanguageModelEmbeddingsConfig(BlockLayerConfig): +class LanguageModelEmbeddingsConfig(BlockConfig): _abstract = False word_embeddings: ParameterConfig = Field( desc="Configuration for the word embedding (weight).", @@ -52,6 +52,12 @@ class LanguageModelEmbeddingsConfig(BlockLayerConfig): desc="Configuration for the word embedding (weight).", hint=FieldHint.architecture, ) + hidden_size: int = Field( + default=1024, + desc="Size of the model's main hidden dimension, e.g., for its input and output layers.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", @@ -72,7 +78,7 @@ class LanguageModelEmbeddingsConfig(BlockLayerConfig): ) full_precision_residual: bool = Field( default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + desc="Store the residuals for the model in full precision (`optimization_dtype`).", hint=FieldHint.stability, ) @@ -104,7 +110,7 @@ def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Prepr @config_class() -class LanguageModelHeadConfig(BlockLayerConfig): +class LanguageModelHeadConfig(BlockConfig): _abstract = False normalization: NormalizationConfig = Field( desc="Configuration for the final normalization layer.", @@ -234,7 +240,36 @@ def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Prepr return preprocessors - def get_layer( + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + loss_defs = [] + if self.logit_z_loss: + LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=count) + + if self.enable_dpo: + loss_defs.append(LossDef(name=LanguageModelLossNames.dpo_loss, formatted_name="dpo loss", count=count)) + + if self.distillation_model is not None: + loss_defs.append( + LossDef(name=LanguageModelLossNames.distillation_loss, formatted_name="distillation loss", count=count) + ) + if self.language_model_loss_factor > 0.0: + loss_defs.append( + LossDef( + name=LanguageModelLossNames.distil_lm_loss, formatted_name="distillation lm loss", count=count + ) + ) + + for i in range(self.prediction_heads): + loss_defs.append( + LossDef( + name=LanguageModelLossNames.multi_token_prediction_loss(i), + formatted_name=f"language model loss {i}", + count=count, + ) + ) + return loss_defs + + def get_block( self, distributed_config: DistributedConfig, embeddings_config: LanguageModelEmbeddingsConfig, @@ -254,12 +289,49 @@ def get_layer( prediction_distance=prediction_distance, ) + def get_blocks( + self, + distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + mtp_block_config: BlockConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + blocks = [] + for i in range(self.prediction_heads): + if i > 0: + blocks.append( + mtp_block_config.get_block( + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + # The last block only returns the model output. + # The previous blocks return a stack of shared_hidden and transformer_output. + return_input=i < self.prediction_heads - 1, + ) + ) + blocks.append( + self.get_block( + distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + prediction_distance=i, + ) + ) + return blocks + +# TODO: `BlockSequenceConfig`? (interface not fully compatible) @config_class() class LanguageModelBaseConfig(BaseModelConfig): # TODO: block - transformer: BlockConfig = Field( - desc="Configuration for the transformer architecture.", + decoder: BlockSequenceConfig = Field( + desc="Configuration for the language model decoder.", hint=FieldHint.architecture, ) embeddings_layer: LanguageModelEmbeddingsConfig = Field() @@ -292,9 +364,61 @@ def from_flat_dict( cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") return super().from_flat_dict(default, strict) + def __len__(self) -> int: + return len(self.decoder) + 2 * self.output_layer.prediction_heads + + def __getitem__(self, index: int) -> BlockConfig: + if index <= 0: + Assert.eq(index, 0) + return self.embeddings_layer + elif index <= len(self.decoder): + return self.decoder[index - 1] + else: + # Start at the last decoder layer so all MTP heads are treated similarly. + index - len(self.decoder) + return self.embeddings_layer + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: return ( self.embeddings_layer.get_preprocessors(distributed_config) - + self.transformer.get_preprocessors(distributed_config) + + self.decoder.get_preprocessors(distributed_config) + self.output_layer.get_preprocessors(distributed_config) ) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return ( + self.embeddings_layer.get_loss_definitions(count) + + self.decoder.get_loss_definitions(count) + + self.output_layer.get_loss_definitions(count) + ) + + def get_blocks(self, distributed_config: DistributedConfig): + hidden_dim = TensorDim("hidden", self.embeddings_layer.hidden_size) + return [ + self.embeddings_layer.get_block( + distributed_config, + hidden_dim=hidden_dim, + lr_scale=None, + peft=self.peft, + ), + *[ + self.decoder[i].get_block( + distributed_config, + hidden_dim, + lr_scale=None, + peft=self.peft, + # The last layer only returns the transformer output. + # The previous layers return a stack of shared_hidden and transformer_output. + return_input=self.output_layer.prediction_heads > 1 and i == len(self.decoder) - 1, + ) + for i in range(len(self.decoder)) + ], + *self.output_layer.get_blocks( + distributed_config, + self.embeddings_layer, + self.decoder[len(self.decoder) - 1], + hidden_dim=hidden_dim, + lr_scale=None, + peft=self.peft, + ), + ] diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index b7a780a33..e0661cfa2 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -4,12 +4,11 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import reduce_forward, split -from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta @@ -18,7 +17,7 @@ WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" -class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](BlockLayerBase[ConfigType], Layer): +class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[ConfigType]): """ A language model embedding layer. Consists of word embeddings (tensor-parallel or sequence-tensor-parallel), @@ -37,13 +36,17 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_input: bool = False, ): + if return_input: + raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_input=return_input, ) self._residual_dtype = ( self._distributed_config.optimization_dtype diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index e71512915..ade1144d2 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -6,7 +6,6 @@ from torch.distributed import all_reduce from fast_llm.core.ops import split_op -from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim @@ -16,7 +15,7 @@ from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward -from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.peft.config import PeftConfig @@ -35,7 +34,7 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelHeadConfig](BlockLayerBase[ConfigType], Layer): +class LanguageModelHead[ConfigType: LanguageModelHeadConfig](Block[ConfigType]): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). TODO: Cleanup (dynamic type? composition?) @@ -53,13 +52,17 @@ def __init__( lr_scale: float | None, peft: PeftConfig | None, prediction_distance: int, + return_input: bool = False, ): + if return_input: + raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_input=return_input, ) self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -179,7 +182,7 @@ def _forward_backward( if self._sequence_parallel_logits else TensorDim(BlockDimNames.sequence_q, dims[sequence_index].global_size) ) - meta = TensorMeta.from_dims(tuple(dims), tensor_name="transformer hidden_state", dtype=ln_output.dtype) + meta = TensorMeta.from_dims(tuple(dims), tensor_name="hidden_state", dtype=ln_output.dtype) hidden_state, _ = meta.local_to_global(ln_output.detach()) kwargs["hidden_states"][len(kwargs["hidden_states"]) - 1]["tensor"] = hidden_state diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 9b89b28cd..e541341e5 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,12 +1,11 @@ -import enum import math import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer from fast_llm.engine.config_utils.parameter import ParameterConfig -from fast_llm.layers.block.config import MixerConfig from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig +from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -17,38 +16,6 @@ from fast_llm.tensor import ParameterMeta -class SSMBlockType(enum.StrEnum): - """ - An enum for the available mamba types for the MLP layer. - """ - - mamba = "m" - mamba2_discrete = "m2d" - mamba2 = "m2" - transformer = "t" - - def get_mixer_class(self): - if self == SSMBlockType.mamba: - from fast_llm.layers.ssm.mamba import Mamba - - return Mamba - elif self == SSMBlockType.mamba2: - from fast_llm.layers.ssm.mamba2 import Mamba2 - - return Mamba2 - elif self == SSMBlockType.mamba2_discrete: - from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 - - return DiscreteMamba2 - else: - raise NotImplementedError(self) - - -class DTInitType(enum.StrEnum): - constant = "constant" - random = "random" - - @config_class() class SSMConfig(MixerConfig): # Layers diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index a7d059781..f014012b2 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -9,9 +9,9 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType -from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.ssm.config import DiscreteMamba2Config from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import div @@ -27,12 +27,13 @@ _mamba_available = False -class DiscreteMamba2[ConfigType: DiscreteMamba2Config](BlockLayer[ConfigType]): +class DiscreteMamba2[ConfigType: DiscreteMamba2Config](BlockWithBias[ConfigType]): """ This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py """ _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + _config: DiscreteMamba2Config def __init__( self, @@ -104,6 +105,7 @@ def __init__( self.convolution = self._config.convolution_layer.get_layer( convolution_dim, + default_add_bias=self._config.add_linear_biases, default_activation=ActivationType.silu, lr_scale=self._lr_scale, peft=self._peft, diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 5caa1a97c..e77a4468b 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -8,9 +8,9 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.ssm.config import MambaConfig, init_a, init_dtprojbias from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -31,8 +31,9 @@ """ -class Mamba[ConfigType: MambaConfig](BlockLayer[ConfigType]): +class Mamba[ConfigType: MambaConfig](BlockWithBias[ConfigType]): _mixer_name: typing.ClassVar[str] = "mamba" + _config: MambaConfig def __init__( self, @@ -72,7 +73,7 @@ def __init__( self.convolution = self._config.convolution_layer.get_layer( inner_dim, default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), - default_add_bias=False, + default_add_bias=self._config.add_linear_biases, default_activation=ActivationType.silu, lr_scale=self._lr_scale, peft=self._peft, @@ -91,6 +92,7 @@ def __init__( inner_dim, default_weight_initialization=init_normal_(0, (2 / self._config.d_inner) ** 0.5), default_bias_initialization=init_dtprojbias(), + default_add_bias=self._config.add_linear_biases, lr_scale=self._lr_scale, peft=self._peft, ) @@ -113,7 +115,7 @@ def __init__( inner_dim, hidden_dim, default_weight_initialization=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), - default_add_bias=False, + default_add_bias=self._config.add_linear_biases, lr_scale=self._lr_scale, peft=self._peft, ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b48f100db..b0657313d 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -8,9 +8,9 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType -from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.ssm.config import Mamba2Config, init_a, init_dtprojbias from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -25,12 +25,13 @@ logger = logging.getLogger(__name__) -class Mamba2[ConfigType: Mamba2Config](BlockLayer[ConfigType]): +class Mamba2[ConfigType: Mamba2Config](BlockWithBias[ConfigType]): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ _mixer_name: typing.ClassVar[str] = "mamba_2" + _config: Mamba2Config def __init__( self, @@ -81,6 +82,7 @@ def __init__( self.convolution = self._config.convolution_layer.get_layer( convolution_dim, + default_add_bias=self._config.add_linear_biases, default_activation=ActivationType.silu, lr_scale=self._lr_scale, peft=self._peft, @@ -108,6 +110,7 @@ def __init__( inner_dim, default_weight_initialization=init_uniform_centered_(self._config.dt_rank**-0.5), default_bias_initialization=init_dtprojbias(), + default_add_bias=self._config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, @@ -223,7 +226,7 @@ def forward( c, self.D.float(), z, - delta_bias=self.dt_proj.bias.float(), + delta_bias=None if self.dt_proj.bias is None else self.dt_proj.bias.float(), delta_softplus=True, ) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 1bc30aeab..931c7f644 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -6,7 +6,7 @@ import torch import torch._dynamo # noqa -from fast_llm.engine.base_model.base_model import LossDef +from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.logging import TensorLogs from fast_llm.engine.distributed.config import PhaseType from fast_llm.tensor import TensorMeta diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index a5860096e..322932664 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -2,8 +2,6 @@ Import these submodules to ensure classes are added to the dynamic class registry. """ -from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig # isort: skip +from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip -from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridSSMTrainerConfig # isort: skip - from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py deleted file mode 100644 index aa304d396..000000000 --- a/fast_llm/models/custom/config.py +++ /dev/null @@ -1,62 +0,0 @@ -import typing - -from fast_llm.config import FieldUpdate, config_class -from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.engine.training.config import TrainerConfig -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig - -if typing.TYPE_CHECKING: - from fast_llm.models.custom.huggingface import HuggingfaceCustomModelForCausalLM - from fast_llm.models.custom.model import CustomModel - from fast_llm.models.custom.trainer import CustomTrainer - - -@config_class() -class CustomDataConfig(GPTDataConfig): - # TODO: If needed, inherit from AbstractDataConfig instead and re-implement everything. - pass - - -@config_class() -class CustomBaseModelConfig(GPTBaseModelConfig): - # TODO: Add custom other base model config parameters, if any. - pass - - -@config_class(dynamic_type={FastLLMModelConfig: "gpt_custom"}) -class CustomModelConfig(GPTModelConfig): - # TODO: Add custom model config parameters, if any (typically none). - model_name: typing.ClassVar[str] = "gpt_custom" - base_model: CustomBaseModelConfig = FieldUpdate() - - @classmethod - def get_model_class(cls) -> type["CustomModel"]: - from fast_llm.models.custom.model import CustomModel - - return CustomModel - - @classmethod - def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceCustomModelForCausalLM"]: - from fast_llm.models.custom.huggingface import HuggingfaceCustomModelForCausalLM - - return HuggingfaceCustomModelForCausalLM - - -@config_class() -class PretrainedCustomModelConfig(PretrainedGPTModelConfig): - model: CustomModelConfig = FieldUpdate() - - -@config_class(dynamic_type={RunnableConfig: "train_gpt_custom", TrainerConfig: "gpt_custom"}) -class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): - # TODO: Add custom trainer config parameters, if any (typically none). - data: CustomDataConfig = FieldUpdate() - reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate() - - @classmethod - def get_trainer_class(cls) -> type["CustomTrainer"]: - from fast_llm.models.custom.trainer import CustomTrainer - - return CustomTrainer diff --git a/fast_llm/models/custom/data.py b/fast_llm/models/custom/data.py deleted file mode 100644 index 45ffd9edb..000000000 --- a/fast_llm/models/custom/data.py +++ /dev/null @@ -1,48 +0,0 @@ -import pathlib -import typing - -from fast_llm.data.data.gpt.data import GPTData -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.schedule.config import BatchConfig -from fast_llm.models.custom.config import CustomDataConfig - - -class CustomData(GPTData): - # TODO: If needed, inherit from AbstractData instead and re-implement everything. - def __init__( - self, - config: CustomDataConfig, - distributed_config: DistributedConfig, - vocab_size: int, - max_sequence_length: int, - ): - # TODO: Adjust or reimplement. - super().__init__(config, distributed_config, vocab_size, max_sequence_length) - - def setup( - self, - distributed: Distributed, - samples_per_phase: dict[PhaseType, int], - cache_directory: pathlib.Path, - ): - # TODO: Adjust or reimplement. - return super().setup(distributed, samples_per_phase, cache_directory) - - def get_iterator( - self, - batch_config: BatchConfig, - phase: PhaseType, - *, - consumed_samples: int, - num_workers: int, - prefetch_factor: int | None = None, - ) -> typing.Iterator[typing.Any]: - # TODO: Adjust or reimplement. - return super().get_iterator( - batch_config, - phase, - consumed_samples=consumed_samples, - num_workers=num_workers, - prefetch_factor=prefetch_factor, - ) diff --git a/fast_llm/models/custom/head.py b/fast_llm/models/custom/head.py deleted file mode 100644 index 786e36929..000000000 --- a/fast_llm/models/custom/head.py +++ /dev/null @@ -1,6 +0,0 @@ -from fast_llm.layers.language_model.head import LanguageModelHead - - -class CustomHead(LanguageModelHead): - # TODO: Implement custom parts - pass diff --git a/fast_llm/models/custom/huggingface.py b/fast_llm/models/custom/huggingface.py deleted file mode 100644 index 7db4e73f8..000000000 --- a/fast_llm/models/custom/huggingface.py +++ /dev/null @@ -1,18 +0,0 @@ -from fast_llm.models.custom.config import CustomModelConfig -from fast_llm.models.custom.model import CustomModel -from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelConfig, HuggingfaceGPTModelForCausalLM - - -class HuggingfaceCustomModelConfig(HuggingfaceGPTModelConfig): - model_type = "fast_llm_gpt_custom" - model_config_class = CustomModelConfig - fast_llm_config: CustomModelConfig - - -class HuggingfaceCustomModelForCausalLM(HuggingfaceGPTModelForCausalLM): - # TODO: Implement changes in huggingface interface, if any. - # Ex.: Return predictions instead of logits. - config_class = HuggingfaceCustomModelConfig - config: HuggingfaceCustomModelConfig - model_class = CustomModel - _fast_llm_model: CustomModel diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py deleted file mode 100644 index 3afd88ce1..000000000 --- a/fast_llm/models/custom/model.py +++ /dev/null @@ -1,59 +0,0 @@ -import typing - -import torch - -from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.base_model.base_model import LossDef -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.schedule.config import BatchConfig -from fast_llm.models.custom.config import CustomBaseModelConfig -from fast_llm.models.custom.head import CustomHead -from fast_llm.models.gpt.model import GPTBaseModel, GPTModel -from fast_llm.tensor import TensorMeta - - -class CustomBaseModel[ConfigType: CustomBaseModelConfig](GPTBaseModel[ConfigType]): - def __init__( - self, - config: ConfigType, - distributed_config: DistributedConfig, - ): - # TODO: Implement / update. - super().__init__(config, distributed_config) - - def _get_head(self, prediction_distance): - return CustomHead( - self._config, - self._distributed_config, - self._hidden_dim, - max(self._config.transformer.num_layers + prediction_distance, 1), - f"Language model head {prediction_distance}", - prediction_distance=prediction_distance, - ) - - def preprocess_meta( - self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType - ) -> list[tuple[TensorMeta, dict]]: - # TODO: Adjust or reimplement. - return super().preprocess_meta(batch_meta, phase) - - def preprocess( - self, - batch: GPTBatch, - preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, - *, - phase: PhaseType, - iteration: int, - metrics: dict | None = None, - ) -> list[tuple[torch.Tensor, dict]]: - # TODO: Adjust or reimplement. - return super().preprocess(batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics) - - @property - def loss_defs(self) -> list[LossDef]: - # TODO: Adjust or reimplement. - return super().loss_defs - - -class CustomModel[ConfigType: CustomBaseModelConfig](GPTModel[ConfigType]): - base_model_class: typing.ClassVar[type[CustomBaseModel]] = CustomBaseModel diff --git a/fast_llm/models/custom/readme.md b/fast_llm/models/custom/readme.md deleted file mode 100644 index ca0059084..000000000 --- a/fast_llm/models/custom/readme.md +++ /dev/null @@ -1,38 +0,0 @@ -# Custom model template - -The "custom" model is a template for customized training of a GPT-style model, -for example to fine-tune it for a particular class. -This is typically done as follows: - -1. Create a copy of the `custom` model, and rename it appropriately, ex. `my_model`, `MyModelTrainer`, etc. -2. If necessary, adjust the base classes to inherit from more abstract classes or another model. -ex. `MyModelData(AbstractData)` to re-implement data processing from scratch. -3. Add custom configuration fields in `config.py`. -4. Adapt or re-implement the data loading scheme in `MyModelData`. -5. Adapt or re-implement the preprocessing scheme in `MyModelBaseModel`. -6. Adapt or re-implement the model head, ex. change the task and/or add a custom loss. -7. If needed, adapt the huggingface interface to return outputs for the desired task. -8. Apply other changes as needed. -9. Add the new model to the registry (`models.auto.py`) so it can be used through the cli. -10. Run training with the new model, ex. `fast-llm train my_model [...]`. - -## Preprocessing variables and kwargs - -To pass additional parameters to the model during preprocessing, ex. a target for the loss or a runtime parameter, -simply add them to the returned `kwargs`. -Those kwargs will be passed directly to the `forward` method of each layer and can be used as needed. - -In some cases, it may be desirable to modify the `kwargs` inside a layer, -for example to pass additional data to other layers or to the backward pass. -This possible with certain caveats: - -* There is no direct support for autograd. Detaching tensors is recommended to prevent memory losses. -* Such modifications may be incompatible with pipeline parallelism, -as the data will not be transferred to pipeline-parallel devices. - -## Disclaimer - -Model customization is a work in progress. -Some abstractions may be missing or poorly implemented, -and some methods and variables may be hard-coded or very difficult to override. -We intend to address these issues in the future, but it will most likely incur some breaking changes in the interface. diff --git a/fast_llm/models/custom/trainer.py b/fast_llm/models/custom/trainer.py deleted file mode 100644 index 587adad3e..000000000 --- a/fast_llm/models/custom/trainer.py +++ /dev/null @@ -1,15 +0,0 @@ -from fast_llm.models.custom.config import CustomTrainerConfig -from fast_llm.models.custom.data import CustomData -from fast_llm.models.gpt.trainer import GPTTrainer - - -class CustomTrainer[ConfigType: CustomTrainerConfig](GPTTrainer[ConfigType]): - # TODO: Implement changes in the training loop (or tflops computation), if any (typically none). - def _get_data(self): - # TODO: Adjust signature if needed. - return CustomData( - config=self._config.data, - distributed_config=self._config.model.distributed, - vocab_size=self._config.model.base_model.vocab_size, - max_sequence_length=self._config.batch.sequence_length, - ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 10897cd9a..9cd77ff37 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,12 +4,23 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.models.gpt.conversion.config import ( + AprielHybridSSMCheckpointFormat, + AutoGPTHuggingfaceCheckpointFormat, + DiffusionDreamCheckpointFormat, + DiffusionLlamaCheckpointFormat, + LlamaCheckpointFormat, + MistralCheckpointFormat, + MixtralCheckpointFormat, + MTPLlamaCheckpointFormat, + Qwen2CheckpointFormat, +) from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds from fast_llm.utils import Assert, div @@ -21,52 +32,6 @@ logger = logging.getLogger(__name__) -class GPTHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.gpt.conversion import AutoGPTHuggingfaceCheckpointHandler - - return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.name) - - -class AutoGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "auto" - - -class Starcoder2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "starcoder2" - - -class LlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "llama" - - -class Qwen2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "qwen2" - - -class MistralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "mistral" - - -class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "mixtral" - - -class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "mtp_llama" - - -class DiffusionDreamGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "dream" - - -class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "diffusion_llama" - - @config_class() class GPTBatchConfig(BatchConfig): sequence_length: int = Field( @@ -151,14 +116,14 @@ class GPTModelConfig(FastLLMModelConfig): base_model: GPTBaseModelConfig = FieldUpdate() checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( AutoGPTHuggingfaceCheckpointFormat, - Starcoder2GPTHuggingfaceCheckpointFormat, - LlamaGPTHuggingfaceCheckpointFormat, - Qwen2GPTHuggingfaceCheckpointFormat, - MistralGPTHuggingfaceCheckpointFormat, - MixtralGPTHuggingfaceCheckpointFormat, - MTPLlamaGPTHuggingfaceCheckpointFormat, - DiffusionDreamGPTHuggingfaceCheckpointFormat, - DiffusionLlamaGPTHuggingfaceCheckpointFormat, + LlamaCheckpointFormat, + Qwen2CheckpointFormat, + MistralCheckpointFormat, + MixtralCheckpointFormat, + MTPLlamaCheckpointFormat, + DiffusionDreamCheckpointFormat, + DiffusionLlamaCheckpointFormat, + AprielHybridSSMCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py deleted file mode 100644 index 3cb954e1d..000000000 --- a/fast_llm/models/gpt/conversion.py +++ /dev/null @@ -1,856 +0,0 @@ -import abc -import dataclasses -import logging -import typing - -import torch -from transformers.configuration_utils import PretrainedConfig - -from fast_llm.config import DEFAULT, MISSING -from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import ( - AutoStateDictCheckpointHandler, - ConstantExportParamConverter, - ConstantImportParamConverter, - IgnoreExportWeightConverter, - IgnoreImportParamConverter, - IgnoreImportWeightConverter, - MappedConfigParamConverter, - ParamConverter, - RenameParamConverter, - SplitWeightConverter, - WeightConverter, -) -from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.functional.config import ActivationType -from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig -from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex -from fast_llm.layers.block.config import BlockConfig -from fast_llm.layers.block.mlp.config import RoutingType -from fast_llm.layers.common.normalization.config import LayerNormalizationConfig -from fast_llm.models.gpt.config import ( - DiffusionDreamGPTHuggingfaceCheckpointFormat, - DiffusionLlamaGPTHuggingfaceCheckpointFormat, - GPTBaseModelConfig, - GPTModelConfig, - LlamaGPTHuggingfaceCheckpointFormat, - MistralGPTHuggingfaceCheckpointFormat, - MixtralGPTHuggingfaceCheckpointFormat, - MTPLlamaGPTHuggingfaceCheckpointFormat, - Qwen2GPTHuggingfaceCheckpointFormat, - Starcoder2GPTHuggingfaceCheckpointFormat, -) -from fast_llm.models.gpt.external.diffusion_dream.configuration_dream import DreamConfig -from fast_llm.models.gpt.external.diffusion_llama.configuration_diffusion_llama import DiffusionLlamaConfig -from fast_llm.models.gpt.external.mtp_llama.configuration_mtp_llama import MTPLlamaConfig -from fast_llm.models.gpt.model import GPTModel -from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert, div - -if typing.TYPE_CHECKING: - pass - -logger = logging.getLogger(__name__) - - -@dataclasses.dataclass -class HiddenSizeParamConverter(ParamConverter): - """ - Some HF models don't have a `head_dim` parameter, and instead use hidden_size // heads - """ - - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 3) - Assert.eq(len(self.export_names), 2) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - hidden_size, heads, head_size = fast_llm_values - Assert.eq(head_size * heads, hidden_size) - return hidden_size, heads - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - hidden_size, heads = export_values - return hidden_size, heads, div(hidden_size, heads) - - -class QueryWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings. - _config: GPTBaseModelConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (query,) = weight - if self._config.transformer.mixer.rotary.complex_format: - query = convert_rotary_complex_to_real(query[:], self._config.transformer.mixer.head_size, 0) - return (query,) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (query,) = weight - if self._config.transformer.mixer.rotary.complex_format: - query = convert_rotary_real_to_complex(query[:], self._config.transformer.mixer.head_size, 0) - return (query,) - - -class KeyValueWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings, and keeps the key and value separate. - _config: GPTBaseModelConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (key_value,) = weight - key, value = key_value[:].chunk(2) - if self._config.transformer.mixer.rotary.complex_format: - key = convert_rotary_complex_to_real(key, self._config.transformer.mixer.head_size, 0) - return key, value - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - key, value = weight - if self._config.transformer.mixer.rotary.complex_format: - key = convert_rotary_real_to_complex(key[:], self._config.transformer.mixer.head_size, 0) - key_value = torch.cat([key[:], value[:]]) - return (key_value,) - - -class MLPLayer2Converter(WeightConverter): - # Similar to SplitWeightConverter, but handles the optional MLP transpose. - # Still ok for non-gated (trivial split) and biases (trivial 1d transpose) - _config: GPTBaseModelConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (merged_weight,) = weight - return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) - return (merged_weight.t().contiguous(),) - - -class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): - _model: GPTModel - _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig - architecture: typing.ClassVar[str] - """ - Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) - """ - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), - ConstantImportParamConverter( - fast_llm_names=( - ( - "embeddings_layer", - "position_embeddings", - "enabled", - ), - ), - fast_llm_value=False, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mixer", "rotary", "theta"),), export_names=(("rope_theta",),) - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "mlp", "activation"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("num_hidden_layers",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mixer", "head_groups"),), - export_names=(("num_key_value_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mlp", "intermediate_size"),), - export_names=(("intermediate_size",),), - ), - RenameParamConverter( - fast_llm_names=( - ( - "embeddings_layer", - "vocab_size", - ), - ), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=( - ( - "output_layer", - "tied_weight", - ), - ), - export_names=(("tie_word_embeddings",),), - ), - ] - - @abc.abstractmethod - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - pass - - def _create_weight_converters( - self, - ) -> list[WeightConverter]: - converters = [] - num_layers = self._model.config.base_model.transformer.num_layers - - # Embeddings - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - - converters += self._create_lm_head_converters() - - for i in range(num_layers): - converters += self._create_transformer_layer_converters(f"layers.{i+1}", f"model.layers.{i}") - - return converters - - def _create_transformer_layer_converters( - self, fast_llm_layer_name: str, hf_layer_name: str, ignore_export: bool = False - ) -> list[WeightConverter]: - transformer_config: BlockConfig = self._model.config.base_model.transformer - norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) - converters = [] - names_bias_cls = [ - # Self-attn - ( - f"{fast_llm_layer_name}.mixer.query", - f"{hf_layer_name}.self_attn.q_proj", - # TODO: Fix - transformer_config.mixer.add_linear_biases, - QueryWeightConverter, - ), - ( - f"{fast_llm_layer_name}.mixer.key_value", - (f"{hf_layer_name}.self_attn.k_proj", f"{hf_layer_name}.self_attn.v_proj"), - # TODO: Fix - transformer_config.mixer.add_linear_biases, - KeyValueWeightConverter, - ), - ( - f"{fast_llm_layer_name}.mixer.dense", - f"{hf_layer_name}.self_attn.o_proj", - # TODO: Fix - transformer_config.mixer.add_linear_biases, - WeightConverter, - ), - # Norm - ( - f"{fast_llm_layer_name}.norm_1", - f"{hf_layer_name}.input_layernorm", - norm_bias, - WeightConverter, - ), - ( - f"{fast_llm_layer_name}.norm_2", - f"{hf_layer_name}.post_attention_layernorm", - norm_bias, - WeightConverter, - ), - ] - for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: - converters += self._get_weight_and_bias_converters( - fast_llm_prefix, - () if ignore_export else hf_prefix, - use_bias, - cls=IgnoreExportWeightConverter if ignore_export else cls, - ) - - # MLP - if ignore_export: - converters += self._get_weight_and_bias_converters( - f"{fast_llm_layer_name}.mlp.layer_1", - (), - # TODO: Fix - transformer_config.mlp.add_linear_biases, - cls=IgnoreExportWeightConverter, - ) - converters += self._get_weight_and_bias_converters( - f"{fast_llm_layer_name}.mlp.layer_2", - (), - # TODO: Fix - transformer_config.mlp.add_linear_biases, - cls=IgnoreExportWeightConverter, - ) - converters += [IgnoreExportWeightConverter(f"{fast_llm_layer_name}.mlp.router.weight", ())] - else: - converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") - return converters - - def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers - prediction_heads = self._model.config.base_model.output_layer.prediction_heads - norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) - converters = [] - - # Next-token prediction head - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias - ) - # Output weights - if self._model.config.base_model.output_layer.tied_weight: - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) - else: - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) - - # MTP-heads > 0 are thrown away - for i in range(1, prediction_heads): - logger.warning( - f"The model weights for the multi-token prediction head {i} are discarded during conversion." - ) - mtp_transformer_layer_index = num_layers - 1 + 2 * i - # MTP transformer layer - converters += self._create_transformer_layer_converters( - f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True - ) - # MTP output norm - converters += self._get_weight_and_bias_converters( - f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter - ) - - return converters - - def _get_weight_and_bias_converters( - self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, - ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) - ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) - ) - return converters - - -class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = Starcoder2GPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "Starcoder2ForCausalLM" - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - HiddenSizeParamConverter( - fast_llm_names=( - ("transformer", "hidden_size"), - ("transformer", "mixer", "heads"), - ("transformer", "mixer", "head_size"), - ), - export_names=(("hidden_size",), ("num_attention_heads",)), - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "mixer", "rotary", "type"),), - fast_llm_value=DefaultRotaryConfig.dynamic_type_name, - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "mixer", "add_linear_biases"),), fast_llm_value=True - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), - fast_llm_value="layer_norm", - ), - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "gated"),), fast_llm_value=False), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "mlp", "add_linear_biases"),), fast_llm_value=True - ), - ConstantImportParamConverter( - fast_llm_names=(("output_layer", "normalization", "type"),), - fast_llm_value="layer_norm", - ), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: BlockConfig = self._model.config.base_model.transformer - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - f"{hf_prefix}.mlp.c_fc", - # TODO: Fix - transformer_config.mlp.add_linear_biases, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.c_proj", - # TODO: Fix - transformer_config.mlp.add_linear_biases, - MLPLayer2Converter, - ), - ] - - -class CommonLlamaHuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler, abc.ABC): - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), - fast_llm_value="rms_norm", - ), - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("hidden_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mixer", "heads"),), - export_names=(("num_attention_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mixer", "head_size"),), - export_names=(("head_dim",),), - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "mixer", "add_linear_biases"),), fast_llm_value=False - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "gated"),), fast_llm_value=True), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "mlp", "add_linear_biases"),), fast_llm_value=False - ), - LLamaRotaryParamConverter( - fast_llm_names=(("transformer", "mixer", "rotary"),), - export_names=( - ("rope_theta",), - ("rope_scaling",), - ), - ), - ConstantImportParamConverter( - fast_llm_names=(("output_layer", "normalization", "type"),), - fast_llm_value="rms_norm", - ), - ] - - -@dataclasses.dataclass -class LLamaRotaryParamConverter(ParamConverter): - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 1) - Assert.eq(len(self.export_names), 2) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - (rotary_config,) = fast_llm_values - if type(rotary_config) is DefaultRotaryConfig: - rotary_scaling = { - "rope_type": "default", - } - elif type(rotary_config) is Llama3RotaryConfig: - rotary_scaling = { - "rope_type": "llama3", - "factor": rotary_config.scale_factor, - "low_freq_factor": rotary_config.low_frequency_factor, - "high_freq_factor": rotary_config.high_frequency_factor, - "original_max_position_embeddings": rotary_config.original_context_length, - } - elif type(rotary_config) is YarnRotaryConfig: - rotary_scaling = { - "rope_type": "yarn", - "attention_factor": rotary_config.attention_factor, - "beta_fast": rotary_config.beta_fast, - "beta_slow": rotary_config.beta_slow, - "original_max_position_embeddings": rotary_config.original_context_length, - } - else: - raise ValueError(f"Unsupported rotary type: {type(rotary_config).__name__}") - - return rotary_config.theta, rotary_scaling - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - rotary_theta, rope_scaling = export_values - rotary_type = "default" if rope_scaling in (None, MISSING) else rope_scaling.get("rope_type", "default") - rotary_config = { - "type": rotary_type, - "theta": rotary_theta, - } - if rotary_type == "default": - pass - elif rotary_type == "llama3": - rotary_config.update( - { - "scale_factor": rope_scaling.get("factor", DEFAULT), - "low_frequency_factor": rope_scaling.get("low_freq_factor", DEFAULT), - "high_frequency_factor": rope_scaling.get("high_freq_factor", DEFAULT), - "original_context_length": rope_scaling.get("original_max_position_embeddings", DEFAULT), - } - ) - elif rotary_type == "yarn": - rotary_config.update( - { - "attention_factor": rope_scaling.get("attention_factor", DEFAULT), - "beta_fast": rope_scaling.get("beta_fast", DEFAULT), - "beta_slow": rope_scaling.get("beta_slow", DEFAULT), - "original_context_length": rope_scaling.get("original_max_position_embeddings", DEFAULT), - } - ) - return (rotary_config,) # RotaryConfig.from_dict(rotary_config) - - -class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = LlamaGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "LlamaForCausalLM" - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: BlockConfig = self._model.config.base_model.transformer - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - # TODO: Fix - transformer_config.mlp.add_linear_biases, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - # TODO: Fix - transformer_config.mlp.add_linear_biases, - MLPLayer2Converter, - ), - ] - - -@dataclasses.dataclass -class IgnoreImportQwen2SlidingWindowParamsConverter(ParamConverter): - def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 0) - Assert.eq(len(self.export_names), 0) - self.export_names = (("use_sliding_window",), ("sliding_window",), ("max_window_layers",)) - - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (MISSING, MISSING, MISSING) - - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - # Default value for use_sliding_window in Qwen2 HF config is False - if export_values[0] != MISSING and export_values[0] == True: - logger.warning( - f"The configuration parameters `{self.export_names[0]}={export_values[0]}`," - f" `{self.export_names[1]}={export_values[1]}`, `{self.export_names[2]}={export_values[2]}`" - f" are ignored during conversion." - f" If you intend to use them in Fast-LLM, make sure to set them explicitly in the model configuration." - ) - return () - - -class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = Qwen2GPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "Qwen2ForCausalLM" - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), - fast_llm_value="rms_norm", - ), - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) - ), - HiddenSizeParamConverter( - fast_llm_names=( - ("transformer", "hidden_size"), - ("transformer", "mixer", "heads"), - ("transformer", "mixer", "head_size"), - ), - export_names=(("hidden_size",), ("num_attention_heads",)), - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "gated"),), fast_llm_value=True), - # TODO: Fix - ConstantImportParamConverter( - fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv" - ), - LLamaRotaryParamConverter( - fast_llm_names=(("transformer", "mixer", "rotary"),), - export_names=( - ("rope_theta",), - ("rope_scaling",), - ), - ), - ConstantImportParamConverter( - fast_llm_names=(("output_layer", "normalization", "type"),), - fast_llm_value="rms_norm", - ), - IgnoreImportQwen2SlidingWindowParamsConverter(), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: BlockConfig = self._model.config.base_model.transformer - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - # TODO: Fix - transformer_config.mlp.add_linear_biases, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - # TODO: Fix - transformer_config.mlp.add_linear_biases, - MLPLayer2Converter, - ), - ] - - -class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = MistralGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "MistralForCausalLM" - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - return [ - SplitWeightConverter( - f"{fast_llm_prefix}.mlp.layer_1.weight", - (f"{hf_prefix}.mlp.gate_proj.weight", f"{hf_prefix}.mlp.up_proj.weight"), - ), - MLPLayer2Converter( - f"{fast_llm_prefix}.mlp.layer_2.weight", - f"{hf_prefix}.mlp.down_proj.weight", - self._model.config.base_model, - ), - ] - - -class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "MixtralForCausalLM" - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("transformer", "mlp", "type"),), fast_llm_value="moe"), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "mlp", "routing"),), fast_llm_value=RoutingType.topk - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mlp", "experts"),), export_names=(("num_local_experts",),) - ), - RenameParamConverter( - fast_llm_names=(("transformer", "mlp", "experts_per_token"),), - export_names=(("num_experts_per_tok",),), - ), - IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - num_experts = self._model.config.base_model.transformer.mlp.experts - return [ - WeightConverter(f"{fast_llm_prefix}.mlp.router.weight", f"{hf_prefix}.block_sparse_moe.gate.weight"), - SplitWeightConverter( - f"{fast_llm_prefix}.mlp.layer_1.weight", - tuple( - f"{hf_prefix}.block_sparse_moe.experts.{i}.{w}.weight" - for i in range(num_experts) - for w in ("w1", "w3") - ), - ), - MLPLayer2Converter( - f"{fast_llm_prefix}.mlp.layer_2.weight", - tuple(f"{hf_prefix}.block_sparse_moe.experts.{i}.w2.weight" for i in range(num_experts)), - self._model.config.base_model, - ), - ] - - -class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlamaHuggingfaceCheckpointHandler): - from fast_llm.models.gpt.external.mtp_llama import configuration_mtp_llama, modeling_mtp_llama - - format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "MTPLlamaForCausalLM" - modeling_file = modeling_mtp_llama.__file__ - configuration_file = configuration_mtp_llama.__file__ - configuration_cls: typing.ClassVar[type[PretrainedConfig]] = MTPLlamaConfig - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("auto_map",),), - export_value={ - "AutoConfig": "configuration_mtp_llama.MTPLlamaConfig", - "AutoModel": "modeling_mtp_llama.MTPLlamaModel", - "AutoModelForCausalLM": "modeling_mtp_llama.MTPLlamaForCausalLM", - }, - ), - RenameParamConverter( - fast_llm_names=( - ( - "output_layer", - "prediction_heads", - ), - ), - export_names=(("prediction_heads",),), - ), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: BlockConfig = self._model.config.base_model.transformer - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - # TODO: Fix - transformer_config.mlp.add_linear_biases, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - # TODO: Fix - transformer_config.mlp.add_linear_biases, - MLPLayer2Converter, - ), - ] - - # Override base method to handle the MTP heads - def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers - prediction_heads = self._model.config.base_model.output_layer.prediction_heads - norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) - converters = [] - - # Next-token prediction head - # Transformer layer is already handled in the transformer layer converters - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.mtp_norms.0", norm_bias - ) - # Multi-token prediction head - for i in range(1, prediction_heads): - mtp_transformer_layer_index = num_layers - 1 + 2 * i - # MTP transformer layer - converters += self._create_transformer_layer_converters( - f"layers.{mtp_transformer_layer_index + 1}", - f"model.mtp_heads.{i - 1}", - ) - # MTP output norm - converters += self._get_weight_and_bias_converters( - f"layers.{mtp_transformer_layer_index + 2}.final_norm", - f"model.mtp_norms.{i}", - norm_bias, - ) - # Output weights - if self._model.config.base_model.output_layer.tied_weight: - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) - else: - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) - - return converters - - -class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen2HuggingfaceCheckpointHandler): - """ - Handler for DiffusionDream Huggingface checkpoints. - Inherits from Qwen2HuggingfaceCheckpointHandler (and CustomModelingExportMixin), - but overrides _create_config_converters to update architectures and auto_map. - """ - - from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, generation_utils, modeling_dream - - format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "DreamModel" - modeling_file = modeling_dream.__file__ - configuration_file = configuration_dream.__file__ - generation_utils_file = generation_utils.__file__ - configuration_cls: typing.ClassVar[type[PretrainedConfig]] = DreamConfig - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("auto_map",),), - export_value={ - "AutoConfig": "configuration_dream.DreamConfig", - "AutoModel": "modeling_dream.DreamModel", - }, - ), - ] - - -class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlamaHuggingfaceCheckpointHandler): - - from fast_llm.models.gpt.external.diffusion_llama import ( - configuration_diffusion_llama, - generation_utils, - modeling_diffusion_llama, - ) - - format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "DiffusionLlamaModel" - modeling_file = modeling_diffusion_llama.__file__ - configuration_file = configuration_diffusion_llama.__file__ - generation_utils_file = generation_utils.__file__ - configuration_cls: typing.ClassVar[type[PretrainedConfig]] = DiffusionLlamaConfig - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("auto_map",),), - export_value={ - "AutoConfig": "configuration_diffusion_llama.DiffusionLlamaConfig", - "AutoModel": "modeling_diffusion_llama.DiffusionLlamaModel", - }, - ), - # TODO: include when the mask diffusion training is implemented; - # since the imported model (llama) for CPT doesn't have it but the exported model (diffusion llama) does need to have this token. - # RenameParamConverter( - # fast_llm_names=(("mask_token_id",),), - # export_names=(("mask_token_id",),), - # ), - ] - - -class AutoGPTHuggingfaceCheckpointHandler( - AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC -): - - handler_map = { - Starcoder2GPTHuggingfaceCheckpointFormat.name: Starcoder2HuggingfaceCheckpointHandler, - LlamaGPTHuggingfaceCheckpointFormat.name: LlamaHuggingfaceCheckpointHandler, - Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler, - MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, - MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, - MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, - DiffusionDreamGPTHuggingfaceCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler, - DiffusionLlamaGPTHuggingfaceCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler, - } diff --git a/fast_llm/models/gpt/conversion/__init__.py b/fast_llm/models/gpt/conversion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py new file mode 100644 index 000000000..5b32c481d --- /dev/null +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -0,0 +1,374 @@ +import math +import typing + +from transformers import PretrainedConfig + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat +from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters +from fast_llm.models.gpt.conversion.mistral import ( + MistralBaseModelConverter, + MistralBlockConverter, + MistralDecoderConverter, + MistralHeadConverter, + MistralHuggingfaceCheckpointHandler, +) +from fast_llm.utils import Assert, safe_merge_dicts + + +class AprielDiscreteMamba2Converter: + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + return { + "type": "discrete_mamba_2", + "state_size": config["ssm_cfg"]["d_state"], + "d_inner": config["ssm_cfg"].get("d_inner") or hidden_size * config["ssm_cfg"].get("expand", 1), + "add_linear_biases": config["ssm_cfg"]["bias"], + "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, + "n_qk_heads": config["ssm_cfg"]["n_qk_heads"], + "n_v_heads": config["ssm_cfg"]["n_v_heads"], + "chunk_size": config["ssm_cfg"]["chunk_size"], + } + + @classmethod + def export_config(cls, config: DiscreteMamba2Config) -> dict: + cls._check_config(config) + return { + "ssm_cfg": { + "d_state": config.state_size, + "d_inner": config.d_inner, + "bias": config.add_linear_biases, + "conv_bias": ( + config.add_linear_biases + if config.convolution_layer.bias.enabled is None + else config.convolution_layer.bias.enabled + ), + "n_qk_heads": config.n_qk_heads, + "n_v_heads": config.n_v_heads, + "chunk_size": config.chunk_size, + } + } + + @classmethod + def _check_config(cls, config: DiscreteMamba2Config) -> None: + # Opportunity to make derived classes less constrained. + Assert.is_(type(config), DiscreteMamba2Config) + Assert.incl(config.z_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.x_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.b_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.c_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.output_layer.bias.enabled, (None, config.add_linear_biases)) + + @classmethod + def get_converters( + cls, + config: DiscreteMamba2Config, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.in_proj", + f"{hf_prefix}.in_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.convolution", + f"{hf_prefix}.conv1d", + ( + config.add_linear_biases + if config.convolution_layer.bias.enabled is None + else config.convolution_layer.bias.enabled + ), + drop_on_export=drop_on_export, + ), + *( + [] + if config.add_linear_biases + else [ + get_parameter_converter( + f"{fast_llm_prefix}.z_bias", + f"{hf_prefix}.z_bias", + drop_on_export=drop_on_export, + ) + ] + ), + get_parameter_converter( + f"{fast_llm_prefix}.D", + f"{hf_prefix}.D", + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.out_proj", + f"{hf_prefix}.out_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + ] + + +class AprielMamba2Converter: + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + return { + "type": "mamba_2", + "state_size": config["ssm_cfg"]["d_state"], + "d_inner": config["ssm_cfg"].get("d_inner") or hidden_size * config["ssm_cfg"].get("expand", 1), + "add_linear_biases": config["ssm_cfg"]["bias"], + "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, + "d_xb": config["ssm_cfg"].get("d_xb") or hidden_size, + "dt_layer": {"bias": {"enabled": config["ssm_cfg"].get("dt_proj_bias", True)}}, + "dt_rank": ( + math.ceil(hidden_size) + if config["ssm_cfg"].get("dt_rank", "auto") == "auto" + else config["ssm_cfg"]["dt_rank"] + ), + "repeat_kv_before_conv": config["ssm_cfg"].get("repeat_kv_before_conv", True), + } + + @classmethod + def export_config(cls, config: Mamba2Config) -> dict: + cls._check_config(config) + return { + "ssm_cfg": { + "d_state": config.state_size, + "d_inner": config.d_inner, + "bias": config.add_linear_biases, + "conv_bias": ( + config.add_linear_biases + if config.convolution_layer.bias.enabled is None + else config.convolution_layer.bias.enabled + ), + "d_xb": config.d_xb, + "dt_proj_bias": ( + config.add_linear_biases if config.dt_layer.bias.enabled is None else config.dt_layer.bias.enabled + ), + "dt_rank": config.dt_rank, + "repeat_kv_before_conv": config.repeat_kv_before_conv, + } + } + + @classmethod + def _check_config(cls, config: Mamba2Config) -> None: + # Opportunity to make derived classes less constrained. + Assert.is_(type(config), Mamba2Config) + Assert.incl(config.z_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.x_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.b_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.c_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.dt_input_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.output_layer.bias.enabled, (None, config.add_linear_biases)) + + @classmethod + def get_converters( + cls, + config: Mamba2Config, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + # TODO: Conv + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.in_proj", + f"{hf_prefix}.in_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dt_in_proj", + f"{hf_prefix}.dt_in_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dt_proj", + f"{hf_prefix}.dt_proj", + config.add_linear_biases if config.dt_layer.bias.enabled is None else config.dt_layer.bias.enabled, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.convolution", + f"{hf_prefix}.conv1d", + ( + config.add_linear_biases + if config.convolution_layer.bias.enabled is None + else config.convolution_layer.bias.enabled + ), + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.A_log", + f"{hf_prefix}.A_log", + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.D", + f"{hf_prefix}.D", + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.out_proj", + f"{hf_prefix}.out_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + ] + + +class AprielDiscreteMamba2BlockConverter(MistralBlockConverter): + mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter + + +class AprielMamba2BlockConverter(MistralBlockConverter): + mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter + + +class AprielBlockConverter: + layout_names = { + AttentionConfig: "t", + Mamba2Config: "m2", + DiscreteMamba2Config: "m2d", + } + _converter_classes = { + AttentionConfig: MistralBlockConverter, + Mamba2Config: AprielMamba2BlockConverter, + DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter, + } + _config_classes = {value: key for key, value in layout_names.items()} + + @classmethod + def import_config(cls, config: dict, hidden_size: int, layout_name: str = "t") -> dict: + return cls._converter_classes[cls._config_classes[layout_name]].import_config(config, hidden_size) + + @classmethod + def export_config(cls, config) -> dict: + return cls._converter_classes[type(config.mixer)].export_config(config) + + @classmethod + def get_converters( + cls, + config: DecoderBlockConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return cls._converter_classes[type(config.mixer)].get_converters( + config, fast_llm_prefix, hf_prefix, drop_on_export=drop_on_export + ) + + +class AprielDecoderConverter(MistralDecoderConverter): + block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter + + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + layout = config["hybrid_block_layout"] + if len(layout) == 1: + return { + "block": cls.block_converter_class.import_config(config, hidden_size, layout[0]), + "num_blocks": config["num_hidden_layers"], + } + else: + return { + "type": "pattern", + "blocks": { + layout_name: cls.block_converter_class.import_config(config, hidden_size, layout_name) + for layout_name in set(layout) + }, + "pattern": layout, + "num_blocks": config["num_hidden_layers"], + } + + @classmethod + def export_config(cls, config: BlockSequenceConfig) -> dict: + if type(config) is FixedBlockSequenceConfig: + block_configs = [config.block] + pattern_block_configs = [config.block] + elif type(config) is PatternBlockSequenceConfig: + block_configs = config.blocks.values() + pattern_block_configs = [config.blocks[block_name] for block_name in config.pattern] + else: + raise NotImplementedError() + # There may be all sorts of blocks, but `safe_merge_dicts` ensures they are compatible. + return safe_merge_dicts( + *[cls.block_converter_class.export_config(block_config) for block_config in block_configs], + { + "num_hidden_layers": config.num_blocks, + "hybrid_block_layout": [ + cls.block_converter_class.layout_names[type(block_config.mixer)] + for block_config in pattern_block_configs + ], + }, + ) + + @classmethod + def get_converters( + cls, + config: PatternBlockSequenceConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + fast_llm_layer_start: int = 1, + ) -> list[WeightConverter]: + converters = [] + for block_index in range(config.num_blocks): + block_config = config.blocks[config.pattern[block_index % len(config.pattern)]] + converters += cls.block_converter_class.get_converters( + block_config, + f"{fast_llm_prefix}.{block_index+fast_llm_layer_start}", + f"{hf_prefix}.{block_index}", + drop_on_export, + ) + return converters + + +class AprielHeadConverter(MistralHeadConverter): + block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter + + +class AprielBaseModelConverter(MistralBaseModelConverter): + decoder_converter_class: typing.ClassVar[type[AprielDecoderConverter]] = AprielDecoderConverter + head_converter_class: typing.ClassVar[type[AprielHeadConverter]] = AprielHeadConverter + + +class AprielHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = AprielHybridSSMCheckpointFormat + architecture: typing.ClassVar[str] = "AprielHybridSSMForCausalLM" + base_model_converter_class: typing.ClassVar[type[AprielBaseModelConverter]] = AprielBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls) -> type[PretrainedConfig]: + from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig + + return AprielHybridSSMConfig + + @classmethod + def get_model_files(cls) -> tuple[str, str, str | None]: + from fast_llm_external_models.apriel_hybrid_ssm import ( + configuration_apriel_hybrid_ssm, + modeling_apriel_hybrid_ssm, + ) + + return configuration_apriel_hybrid_ssm.__file__, modeling_apriel_hybrid_ssm.__file__, None + + @classmethod + def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: + return safe_merge_dicts( + super()._export_config(config), + { + "auto_map": { + "AutoConfig": "configuration_apriel_hybrid_ssm.AprielHybridSSMConfig", + "AutoModel": "modeling_apriel_hybrid_ssm.AprielHybridSSMModel", + "AutoModelForCausalLM": "modeling_apriel_hybrid_ssm.AprielHybridSSMForCausalLM", + }, + }, + ) diff --git a/fast_llm/models/gpt/conversion/auto.py b/fast_llm/models/gpt/conversion/auto.py new file mode 100644 index 000000000..659d1f12c --- /dev/null +++ b/fast_llm/models/gpt/conversion/auto.py @@ -0,0 +1,38 @@ +import abc + +from fast_llm.engine.checkpoint.external import AutoStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.models.gpt.conversion.apriel import AprielHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.config import ( + AprielHybridSSMCheckpointFormat, + DiffusionDreamCheckpointFormat, + DiffusionLlamaCheckpointFormat, + LlamaCheckpointFormat, + MistralCheckpointFormat, + MixtralCheckpointFormat, + MTPLlamaCheckpointFormat, + Qwen2CheckpointFormat, +) +from fast_llm.models.gpt.conversion.diffusion_dream import DiffusionDreamHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.diffusion_llama import DiffusionLlamaHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.llama import LlamaHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.mistral import MistralHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.mixtral import MixtralHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.mtp_llama import MTPLlamaHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.qwen2 import Qwen2HuggingfaceCheckpointHandler + + +class AutoGPTHuggingfaceCheckpointHandler( + AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC +): + + handler_map = { + LlamaCheckpointFormat.name: LlamaHuggingfaceCheckpointHandler, + Qwen2CheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler, + MistralCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, + MixtralCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, + MTPLlamaCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, + DiffusionDreamCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler, + DiffusionLlamaCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler, + AprielHybridSSMCheckpointFormat.name: AprielHuggingfaceCheckpointHandler, + } diff --git a/fast_llm/models/gpt/conversion/config.py b/fast_llm/models/gpt/conversion/config.py new file mode 100644 index 000000000..7c06906ad --- /dev/null +++ b/fast_llm/models/gpt/conversion/config.py @@ -0,0 +1,49 @@ +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler + + +class GPTHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.gpt.conversion.auto import AutoGPTHuggingfaceCheckpointHandler + + return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.name) + + +class AutoGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "auto" + + +class LlamaCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "llama" + + +class Qwen2CheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "qwen2" + + +class MistralCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "mistral" + + +class MixtralCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "mixtral" + + +class MTPLlamaCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "mtp_llama" + + +class DiffusionDreamCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "dream" + + +class DiffusionLlamaCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "diffusion_llama" + + +class AprielHybridSSMCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "apriel_hybrid_ssm" diff --git a/fast_llm/models/gpt/conversion/diffusion_dream.py b/fast_llm/models/gpt/conversion/diffusion_dream.py new file mode 100644 index 000000000..43742dd68 --- /dev/null +++ b/fast_llm/models/gpt/conversion/diffusion_dream.py @@ -0,0 +1,44 @@ +import typing + +from transformers import PretrainedConfig + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.conversion.config import DiffusionDreamCheckpointFormat +from fast_llm.models.gpt.conversion.qwen2 import Qwen2HuggingfaceCheckpointHandler +from fast_llm.utils import safe_merge_dicts + + +class DiffusionDreamHuggingfaceCheckpointHandler(Qwen2HuggingfaceCheckpointHandler): + """ + Handler for DiffusionDream Huggingface checkpoints. + Inherits from Qwen2HuggingfaceCheckpointHandler (and CustomModelingExportMixin), + but overrides _create_config_converters to update architectures and auto_map. + """ + + format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamCheckpointFormat + architecture: typing.ClassVar[str] = "DreamModel" + + @classmethod + def get_transformers_configuration_class(cls) -> type[PretrainedConfig]: + from fast_llm_external_models.diffusion_dream.configuration_dream import DreamConfig + + return DreamConfig + + @classmethod + def get_model_files(cls) -> tuple[str, str, str | None]: + from fast_llm_external_models.diffusion_dream import configuration_dream, generation_utils, modeling_dream + + return configuration_dream.__file__, modeling_dream.__file__, generation_utils.__file__ + + @classmethod + def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: + return safe_merge_dicts( + super()._export_config(config), + { + "auto_map": { + "AutoConfig": "configuration_dream.DreamConfig", + "AutoModel": "modeling_dream.DreamModel", + }, + }, + ) diff --git a/fast_llm/models/gpt/conversion/diffusion_llama.py b/fast_llm/models/gpt/conversion/diffusion_llama.py new file mode 100644 index 000000000..3343e5f1e --- /dev/null +++ b/fast_llm/models/gpt/conversion/diffusion_llama.py @@ -0,0 +1,42 @@ +import typing + +from transformers import PretrainedConfig + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.conversion.config import DiffusionLlamaCheckpointFormat +from fast_llm.models.gpt.conversion.llama import LlamaHuggingfaceCheckpointHandler +from fast_llm.utils import safe_merge_dicts + + +class DiffusionLlamaHuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaCheckpointFormat + architecture: typing.ClassVar[str] = "DiffusionLlamaModel" + + @classmethod + def get_transformers_configuration_class(cls) -> type[PretrainedConfig]: + from fast_llm_external_models.diffusion_llama.configuration_diffusion_llama import DiffusionLlamaConfig + + return DiffusionLlamaConfig + + @classmethod + def get_model_files(cls) -> tuple[str, str, str | None]: + from fast_llm_external_models.diffusion_llama import ( + configuration_diffusion_llama, + generation_utils, + modeling_diffusion_llama, + ) + + return configuration_diffusion_llama.__file__, modeling_diffusion_llama.__file__, generation_utils.__file__ + + @classmethod + def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: + return safe_merge_dicts( + super()._export_config(config), + { + "auto_map": { + "AutoConfig": "configuration_diffusion_llama.DiffusionLlamaConfig", + "AutoModel": "modeling_diffusion_llama.DiffusionLlamaModel", + }, + }, + ) diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py new file mode 100644 index 000000000..1162db4de --- /dev/null +++ b/fast_llm/models/gpt/conversion/llama.py @@ -0,0 +1,575 @@ +import logging +import typing + +import torch + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import ( + IgnoreExportWeightConverter, + IgnoreImportWeightConverter, + SplitWeightConverter, + WeightConverter, +) +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex +from fast_llm.layers.block.config import FixedBlockSequenceConfig +from fast_llm.layers.common.normalization.config import RMSNormalizationConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelHeadConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat +from fast_llm.models.gpt.model import GPTModel +from fast_llm.tensor import SafeTensorSlice +from fast_llm.utils import Assert, div, safe_merge_dicts + +logger = logging.getLogger(__name__) + + +def get_parameter_converter( + fast_llm_name: str | tuple[str, ...], + hf_name: str | tuple[str, ...], + cls=WeightConverter, + config=None, + drop_on_export: bool = False, + drop_on_import: bool = False, +) -> WeightConverter: + if isinstance(fast_llm_name, str): + fast_llm_name = (fast_llm_name,) + if isinstance(hf_name, str): + hf_name = (hf_name,) + if drop_on_export: + cls = IgnoreExportWeightConverter + if drop_on_import: + cls = IgnoreImportWeightConverter + return cls( + () if drop_on_import else fast_llm_name, + () if drop_on_export else hf_name, + config, + ) + + +def get_weight_and_bias_converters( + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + config=None, + drop_on_export: bool = False, + drop_on_import: bool = False, +) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + get_parameter_converter( + () if drop_on_import else tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + () if drop_on_export else tuple(f"{prefix}.weight" for prefix in hf_prefix), + cls, + config, + drop_on_export, + drop_on_import, + ) + ] + if use_bias: + converters.append( + get_parameter_converter( + () if drop_on_import else tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + () if drop_on_export else tuple(f"{prefix}.bias" for prefix in hf_prefix), + cls, + config, + drop_on_export, + drop_on_import, + ) + ) + return converters + + +class LlamaNormalizationConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return {"type": "rms_norm", "epsilon": config["rms_norm_eps"]} + + @classmethod + def export_config(cls, config: RMSNormalizationConfig) -> dict: + Assert.custom(isinstance, config, RMSNormalizationConfig) + assert not config.zero_centered + return {"rms_norm_eps": config.epsilon} + + @classmethod + def get_converters( + cls, + config: RMSNormalizationConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return get_weight_and_bias_converters( + fast_llm_prefix, + () if drop_on_export else hf_prefix, + False, + IgnoreExportWeightConverter if drop_on_export else WeightConverter, + ) + + +class LlamaMLPConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "intermediate_size": config["intermediate_size"], + "add_linear_biases": config["mlp_bias"], + "activation": ActivationType.from_hf_name(config["hidden_act"]), + "gated": True, + } + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + Assert.custom(isinstance, config, MLPConfig) + Assert.incl(config.layer_1.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) + assert config.gated + return { + "intermediate_size": config.intermediate_size, + "mlp_bias": config.add_linear_biases, + "hidden_act": config.activation.hf_name, + } + + @classmethod + def get_converters( + cls, + config: MLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_1", + (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), + config.add_linear_biases, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.layer_2", + f"{hf_prefix}.down_proj", + config.add_linear_biases, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + + +class MLPLayer2Converter(WeightConverter): + # Similar to SplitWeightConverter, but handles the optional MLP transpose. + # Still ok for non-gated (trivial split) and biases (trivial 1d transpose) + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (merged_weight,) = weight + return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) + return (merged_weight.t().contiguous(),) + + +class LlamaAttentionConverter: + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + try: + rope_type = config["rope_scaling"]["rope_type"] + except (KeyError, TypeError): + rope_type = "default" + rotary_config = { + "type": rope_type, + "theta": config["rope_theta"], + } + if rope_type == "default": + pass + elif rope_type == "llama3": + rotary_config.update( + { + "scale_factor": config["factor"], + "low_frequency_factor": config["low_freq_factor"], + "high_frequency_factor": config["high_freq_factor"], + "original_context_length": config["original_max_position_embeddings"], + } + ) + elif rope_type == "yarn": + rotary_config.update( + { + "attention_factor": config["attention_factor"], + "beta_fast": config["beta_fast"], + "beta_slow": config["beta_slow"], + "original_context_length": config["original_max_position_embeddings"], + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") + out = { + "rotary": rotary_config, + "heads": config["num_attention_heads"], + "head_groups": config["num_key_value_heads"], + "head_size": config.get("head_dim"), + "add_linear_biases": config["attention_bias"], + "dropout": config["attention_dropout"], + } + if out["head_size"] is None: + out["head_size"] = div(hidden_size, out["heads"]) + + return out + + @classmethod + def export_config(cls, config: AttentionConfig) -> dict: + cls._check_config(config) + Assert.eq(config.softmax_scale_power, 0.5) + out = { + "num_attention_heads": config.heads, + "num_key_value_heads": config.head_groups, + "head_dim": config.head_size, + "attention_bias": config.add_linear_biases, + "attention_dropout": config.dropout, + "rope_theta": config.rotary.theta, + } + if type(config.rotary) is DefaultRotaryConfig: + pass + elif type(config.rotary) is Llama3RotaryConfig: + out["rope_scaling"] = { + "rope_type": "llama3", + "factor": config.rotary.scale_factor, + "low_freq_factor": config.rotary.low_frequency_factor, + "high_freq_factor": config.rotary.high_frequency_factor, + "original_max_position_embeddings": config.rotary.original_context_length, + } + elif type(config.rotary) is YarnRotaryConfig: + out["rope_scaling"] = { + "rope_type": "yarn", + "attention_factor": config.rotary.attention_factor, + "beta_fast": config.rotary.beta_fast, + "beta_slow": config.rotary.beta_slow, + "original_max_position_embeddings": config.rotary.original_context_length, + } + else: + raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") + + return out + + @classmethod + def _check_config(cls, config: AttentionConfig) -> None: + # Opportunity to make derived classes less constrained. + Assert.is_(type(config), AttentionConfig) + Assert.incl(config.query_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.key_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.value_layer.bias.enabled, (None, config.add_linear_biases)) + Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) + + @classmethod + def get_converters( + cls, + config: AttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.query", + f"{hf_prefix}.q_proj", + config.add_linear_biases, + QueryWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.key_value", + (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), + config.add_linear_biases, + KeyValueWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dense", + f"{hf_prefix}.o_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + ] + + +class QueryWeightConverter(WeightConverter): + # Hf uses the real format for rotary embeddings. + _config: AttentionConfig + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (query,) = weight + if self._config.rotary.complex_format: + query = convert_rotary_complex_to_real(query[:], self._config.head_size, 0) + return (query,) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (query,) = weight + if self._config.rotary.complex_format: + query = convert_rotary_real_to_complex(query[:], self._config.head_size, 0) + return (query,) + + +class KeyValueWeightConverter(WeightConverter): + # Hf uses the real format for rotary embeddings, and keeps the key and value separate. + _config: AttentionConfig + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (key_value,) = weight + key, value = key_value[:].chunk(2) + if self._config.rotary.complex_format: + key = convert_rotary_complex_to_real(key, self._config.head_size, 0) + return key, value + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + key, value = weight + if self._config.rotary.complex_format: + key = convert_rotary_real_to_complex(key[:], self._config.head_size, 0) + key_value = torch.cat([key[:], value[:]]) + return (key_value,) + + +class LlamaBlockConverter: + mixer_converter_class: typing.ClassVar[type[LlamaAttentionConverter]] = LlamaAttentionConverter + mlp_converter_class: typing.ClassVar[type[LlamaMLPConverter]] = LlamaMLPConverter + normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter + + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + return { + "mixer": cls.mixer_converter_class.import_config(config, hidden_size), + "mlp": cls.mlp_converter_class.import_config(config), + "normalization": cls.normalization_converter_class.import_config(config), + } + + @classmethod + def export_config(cls, config: DecoderBlockConfig) -> dict: + Assert.custom(isinstance, config, DecoderBlockConfig) + return safe_merge_dicts( + cls.mixer_converter_class.export_config(config.mixer), + cls.mlp_converter_class.export_config(config.mlp), + cls.normalization_converter_class.export_config(config.normalization), + ) + + @classmethod + def get_converters( + cls, config: DecoderBlockConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False + ) -> list[WeightConverter]: + return [ + *cls.mixer_converter_class.get_converters( + config.mixer, + f"{fast_llm_prefix}.mixer", + f"{hf_prefix}.self_attn", + drop_on_export, + ), + *cls.mlp_converter_class.get_converters( + config.mlp, + f"{fast_llm_prefix}.mlp", + f"{hf_prefix}.mlp", + drop_on_export, + ), + *cls.normalization_converter_class.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_1", + f"{hf_prefix}.input_layernorm", + drop_on_export, + ), + *cls.normalization_converter_class.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_2", + f"{hf_prefix}.post_attention_layernorm", + drop_on_export, + ), + ] + + +class LlamaDecoderConverter: + block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter + + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + return { + "block": cls.block_converter_class.import_config(config, hidden_size), + "num_blocks": config["num_hidden_layers"], + } + + @classmethod + def export_config(cls, config: FixedBlockSequenceConfig) -> dict: + # TODO: Support PatternBlockSequenceConfig with compatible configs. + Assert.custom(isinstance, config, FixedBlockSequenceConfig) + return safe_merge_dicts( + cls.block_converter_class.export_config(config.block), + {"num_hidden_layers": config.num_blocks}, + ) + + @classmethod + def get_converters( + cls, + config: FixedBlockSequenceConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + fast_llm_layer_start: int = 1, + ) -> list[WeightConverter]: + converters = [] + for block_index in range(config.num_blocks): + converters += cls.block_converter_class.get_converters( + config.block, + f"{fast_llm_prefix}.{block_index+fast_llm_layer_start}", + f"{hf_prefix}.{block_index}", + drop_on_export, + ) + return converters + + +class LlamaEmbeddingsConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "vocab_size": config["vocab_size"], + "hidden_size": config["hidden_size"], + } + + @classmethod + def export_config(cls, config: LanguageModelEmbeddingsConfig) -> dict: + Assert.custom(isinstance, config, LanguageModelEmbeddingsConfig) + assert not config.position_embeddings.enabled + return { + "vocab_size": config.vocab_size, + "hidden_size": config.hidden_size, + } + + @classmethod + def get_converters( + cls, config: LanguageModelEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str + ) -> list[WeightConverter]: + return [WeightConverter(f"{fast_llm_prefix}.word_embeddings_weight", f"{hf_prefix}.embed_tokens.weight")] + + +class LlamaHeadConverter: + normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter + block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter + + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "tied_weight": config["tie_word_embeddings"], + "normalization": cls.normalization_converter_class.import_config(config), + } + + @classmethod + def export_config(cls, config: LanguageModelHeadConfig) -> dict: + Assert.custom(isinstance, config, LanguageModelHeadConfig) + return safe_merge_dicts( + cls.normalization_converter_class.export_config(config.normalization), + {"tie_word_embeddings": config.tied_weight}, + ) + + @classmethod + def get_converters( + cls, config: LanguageModelHeadConfig, block_config: DecoderBlockConfig, fast_llm_prefix: str, start_index: int + ) -> list[WeightConverter]: + converters = [] + for prediction_distance in range(config.prediction_heads): + if prediction_distance > 0: + converters += cls.block_converter_class.get_converters( + block_config, + f"{fast_llm_prefix}.{start_index+2*prediction_distance-1}", + "", + drop_on_export=True, + ) + converters += cls.normalization_converter_class.get_converters( + config.normalization, + f"{fast_llm_prefix}.{start_index+2*prediction_distance}.final_norm", + f"model.norm", + drop_on_export=prediction_distance > 0, + ) + converters.append( + get_parameter_converter( + f"{fast_llm_prefix}.{start_index}.output_weights", + "lm_head.weight", + drop_on_import=config.tied_weight, + ) + ) + + return converters + + +class LlamaBaseModelConverter: + # TODO: Peft? + decoder_converter_class: typing.ClassVar[type[LlamaDecoderConverter]] = LlamaDecoderConverter + embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter + head_converter_class: typing.ClassVar[type[LlamaHeadConverter]] = LlamaHeadConverter + + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "embeddings_layer": cls.embeddings_converter_class.import_config(config), + "decoder": cls.decoder_converter_class.import_config(config, config["hidden_size"]), + "output_layer": cls.head_converter_class.import_config(config), + } + + @classmethod + def export_config(cls, config: GPTBaseModelConfig) -> dict: + Assert.custom(isinstance, config, GPTBaseModelConfig) + return safe_merge_dicts( + cls.embeddings_converter_class.export_config(config.embeddings_layer), + cls.decoder_converter_class.export_config(config.decoder), + cls.head_converter_class.export_config(config.output_layer), + ) + + @classmethod + def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: + return [ + *cls.embeddings_converter_class.get_converters(config.embeddings_layer, "layers.0", "model"), + *cls.decoder_converter_class.get_converters(config.decoder, "layers", "model.layers"), + *cls.head_converter_class.get_converters( + config.output_layer, config.decoder[len(config.decoder) - 1], "layers", len(config.decoder) + 1 + ), + ] + + def _create_weight_converters( + self, + ) -> list[WeightConverter]: + base_model_config = self._model.config.base_model + self.embeddings_converter_class.get_converters(base_model_config.embeddings_layer, "layers.0", "model") + converters = self.decoder_converter_class.get_converters(base_model_config.decoder, "layers", "model.layers") + self.head_converter_class.get_converters( + base_model_config.decoder, base_model_config.decoder.block, "layers", len(base_model_config.decoder) + 1 + ) + return converters + + +class LlamaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + _model: GPTModel + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = LlamaCheckpointFormat + architecture: typing.ClassVar[str] = "LlamaForCausalLM" + base_model_converter_class: typing.ClassVar[type[LlamaBaseModelConverter]] = LlamaBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls): + import transformers + + return transformers.LlamaConfig diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py new file mode 100644 index 000000000..4673f5b2c --- /dev/null +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -0,0 +1,61 @@ +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.models.gpt.conversion.config import MistralCheckpointFormat +from fast_llm.models.gpt.conversion.llama import ( + LlamaAttentionConverter, + LlamaBaseModelConverter, + LlamaBlockConverter, + LlamaDecoderConverter, + LlamaHeadConverter, + LlamaHuggingfaceCheckpointHandler, +) +from fast_llm.utils import safe_merge_dicts + + +class MistralAttentionConverter(LlamaAttentionConverter): + @classmethod + def import_config(cls, config: dict, hidden_size: int) -> dict: + return safe_merge_dicts(super().import_config(config, hidden_size), {"window_size": config["sliding_window"]}) + + @classmethod + def export_config(cls, config: AttentionConfig) -> dict: + return safe_merge_dicts( + super().export_config(config), + {"sliding_window": config.window_size}, + ) + + @classmethod + def _check_config(cls, config: AttentionConfig) -> None: + # Mistral doesn't support biases. + assert not config.add_linear_biases + + +class MistralBlockConverter(LlamaBlockConverter): + mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter + + +class MistralDecoderConverter(LlamaDecoderConverter): + block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter + + +class MistralHeadConverter(LlamaHeadConverter): + block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter + + +class MistralBaseModelConverter(LlamaBaseModelConverter): + decoder_converter_class: typing.ClassVar[type[MistralDecoderConverter]] = MistralDecoderConverter + head_converter_class: typing.ClassVar[type[MistralHeadConverter]] = MistralHeadConverter + + +class MistralHuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = MistralCheckpointFormat + architecture: typing.ClassVar[str] = "MistralForCausalLM" + base_model_converter_class: typing.ClassVar[type[MistralBaseModelConverter]] = MistralBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls): + import transformers + + return transformers.MistralConfig diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py new file mode 100644 index 000000000..428c2d3a3 --- /dev/null +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -0,0 +1,88 @@ +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import SplitWeightConverter, WeightConverter +from fast_llm.layers.decoder.mlp.config import MoEMLPConfig +from fast_llm.models.gpt.conversion.config import MixtralCheckpointFormat +from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, get_weight_and_bias_converters +from fast_llm.models.gpt.conversion.mistral import ( + MistralBaseModelConverter, + MistralBlockConverter, + MistralDecoderConverter, + MistralHeadConverter, + MistralHuggingfaceCheckpointHandler, +) +from fast_llm.utils import Assert, safe_merge_dicts + + +class MixtralMLPConverter(LlamaMLPConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + return safe_merge_dicts( + super().import_config(config), + { + "type": "moe", + "experts": config["num_local_experts"], + "experts_per_token": config["num_experts_per_tok"], + }, + ) + + @classmethod + def export_config(cls, config: MoEMLPConfig) -> dict: + Assert.custom(isinstance, config, MoEMLPConfig) + assert not config.add_linear_biases + return safe_merge_dicts( + super().export_config(config), + { + "num_local_experts": config.experts, + "num_experts_per_tok": config.experts_per_token, + }, + ) + + @classmethod + def get_converters( + cls, + config: MoEMLPConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.router", + () if drop_on_export else (f"{hf_prefix}.router",), + config.add_linear_biases, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *super().get_converters(config, fast_llm_prefix, hf_prefix, drop_on_export=drop_on_export), + ] + + +class MixtralBlockConverter(MistralBlockConverter): + mlp_converter_class: typing.ClassVar[type[MixtralMLPConverter]] = MixtralMLPConverter + + +class MixtralDecoderConverter(MistralDecoderConverter): + block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter + + +class MixtralHeadConverter(MistralHeadConverter): + block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter + + +class MixtralBaseModelConverter(MistralBaseModelConverter): + decoder_converter_class: typing.ClassVar[type[MixtralDecoderConverter]] = MixtralDecoderConverter + head_converter_class: typing.ClassVar[type[MixtralHeadConverter]] = MixtralHeadConverter + + +class MixtralHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = MixtralCheckpointFormat + architecture: typing.ClassVar[str] = "MixtralForCausalLM" + base_model_converter_class: typing.ClassVar[type[MixtralBaseModelConverter]] = MixtralBaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls): + import transformers + + return transformers.MixtralConfig diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py new file mode 100644 index 000000000..194c263f9 --- /dev/null +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -0,0 +1,95 @@ +import typing + +from transformers import PretrainedConfig + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat +from fast_llm.models.gpt.conversion.llama import ( + LlamaBaseModelConverter, + LlamaHeadConverter, + LlamaHuggingfaceCheckpointHandler, + get_parameter_converter, +) +from fast_llm.utils import safe_merge_dicts + + +class MTPLlamaHeadConverter(LlamaHeadConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + return safe_merge_dicts( + super().import_config(config), + {"prediction_heads": config["prediction_heads"]}, + ) + + @classmethod + def export_config(cls, config: LanguageModelHeadConfig) -> dict: + return safe_merge_dicts( + super().export_config(config), + {"prediction_heads": config.prediction_heads}, + ) + + @classmethod + def get_converters( + cls, config: LanguageModelHeadConfig, block_config: DecoderBlockConfig, fast_llm_prefix: str, start_index: int + ) -> list[WeightConverter]: + converters = [] + for prediction_distance in range(config.prediction_heads): + if prediction_distance > 0: + converters += cls.block_converter_class.get_converters( + block_config, + f"{fast_llm_prefix}.{start_index+2*prediction_distance-1}", + f"model.mtp_heads.{prediction_distance - 1}", + ) + converters += cls.normalization_converter_class.get_converters( + config.normalization, + f"{fast_llm_prefix}.{start_index+2*prediction_distance}.final_norm", + f"model.mtp_norms.{prediction_distance}", + ) + converters.append( + get_parameter_converter( + f"{fast_llm_prefix}.{start_index}.output_weights", + "lm_head.weight", + drop_on_import=config.tied_weight, + ) + ) + + return converters + + +class MTPLlamaBaseModelConverter(LlamaBaseModelConverter): + head_converter_class: typing.ClassVar[type[MTPLlamaHeadConverter]] = MTPLlamaHeadConverter + + +class MTPLlamaHuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaCheckpointFormat + architecture: typing.ClassVar[str] = "MTPLlamaForCausalLM" + base_model_converter_class: typing.ClassVar[type[MTPLlamaBaseModelConverter]] = MTPLlamaBaseModelConverter + + @classmethod + def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: + return safe_merge_dicts( + super()._export_config(config), + { + "auto_map": { + "AutoConfig": "configuration_mtp_llama.MTPLlamaConfig", + "AutoModel": "modeling_mtp_llama.MTPLlamaModel", + "AutoModelForCausalLM": "modeling_mtp_llama.MTPLlamaForCausalLM", + }, + }, + ) + + @classmethod + def get_transformers_configuration_class(cls) -> type[PretrainedConfig]: + from fast_llm_external_models.mtp_llama.configuration_mtp_llama import MTPLlamaConfig + + return MTPLlamaConfig + + @classmethod + def get_model_files(cls) -> tuple[str, str, str | None]: + from fast_llm_external_models.mtp_llama import configuration_mtp_llama, modeling_mtp_llama + + return configuration_mtp_llama.__file__, modeling_mtp_llama.__file__, None diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py new file mode 100644 index 000000000..a8bc33454 --- /dev/null +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -0,0 +1,62 @@ +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat +from fast_llm.models.gpt.conversion.llama import ( + LlamaAttentionConverter, + LlamaBaseModelConverter, + LlamaBlockConverter, + LlamaDecoderConverter, + LlamaHeadConverter, + LlamaHuggingfaceCheckpointHandler, +) +from fast_llm.utils import Assert + + +class Qwen2AttentionConverter(LlamaAttentionConverter): + # TODO: Support sliding window with max_window_layers (need 2 kinds of block?) + + @classmethod + def _check_config(cls, config: AttentionConfig) -> None: + Assert.is_(type(config), AttentionConfig) + # There are multiple ways to enable biases on QKV only + if config.add_linear_biases: + Assert.incl(config.query_layer.bias.enabled, (None, True)) + Assert.incl(config.key_layer.bias.enabled, (None, True)) + Assert.incl(config.value_layer.bias.enabled, (None, True)) + Assert.is_(config.dense_layer.bias.enabled, False) + else: + Assert.is_(config.query_layer.bias.enabled, True) + Assert.is_(config.key_layer.bias.enabled, True) + Assert.is_(config.value_layer.bias.enabled, True) + Assert.incl(config.dense_layer.bias.enabled, (None, False)) + + +class Qwen2BlockConverter(LlamaBlockConverter): + mixer_converter_class: typing.ClassVar[type[Qwen2AttentionConverter]] = Qwen2AttentionConverter + + +class Qwen2DecoderConverter(LlamaDecoderConverter): + block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter + + +class Qwen2HeadConverter(LlamaHeadConverter): + block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter + + +class Qwen2BaseModelConverter(LlamaBaseModelConverter): + decoder_converter_class: typing.ClassVar[type[Qwen2DecoderConverter]] = Qwen2DecoderConverter + head_converter_class: typing.ClassVar[type[Qwen2HeadConverter]] = Qwen2HeadConverter + + +class Qwen2HuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = Qwen2CheckpointFormat + architecture: typing.ClassVar[str] = "Qwen2ForCausalLM" + base_model_converter_class: typing.ClassVar[type[Qwen2BaseModelConverter]] = Qwen2BaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls): + import transformers + + return transformers.Qwen2Config diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index bbe7ae43f..f63bd76f8 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -1,8 +1,8 @@ import typing from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig -from fast_llm.layers.block.config import BlockConfig -from fast_llm.layers.block.mlp.config import MoEMLPConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.decoder.mlp.config import MoEMLPConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -14,7 +14,7 @@ def get_init_megatron( - meta: "ParameterMeta", config: BlockConfig + meta: "ParameterMeta", config: DecoderBlockConfig, hidden_size: int ) -> typing.Callable[["torch.Tensor", "Distributed"], None]: def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) @@ -22,13 +22,13 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: # Generator unused. return meta.param_init_method(meta, tensor, distributed.tp_init_generator) if "query" in meta.tensor_name or "key_value" in meta.tensor_name or "dense" in meta.tensor_name: - tensor_ = _init_attention_megatron(config, meta, tensor, distributed) + tensor_ = _init_attention_megatron(config, meta, tensor, distributed, hidden_size) elif "position_embeddings" in meta.tensor_name: tensor_ = _init_position_embeddings_megatron(meta, tensor, distributed) elif "mlp.router.weight" in meta.tensor_name: tensor_ = _init_moe_router_megatron(meta, tensor, distributed) elif isinstance(config.mlp, MoEMLPConfig) and config.mlp.experts > 1 and "mlp.layer_" in meta.tensor_name: - tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) + tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed, hidden_size) elif "mlp.layer_2" in meta.tensor_name: tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) else: @@ -51,7 +51,11 @@ def set_megatron_distributed_seeds(config: "DistributedConfig") -> None: def _init_attention_megatron( - config: BlockConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: DecoderBlockConfig, + meta: "ParameterMeta", + tensor: "torch.Tensor", + distributed: "Distributed", + hidden_size: int, ) -> "torch.Tensor": # Megatron combines q and kv and inverts the initialization order of qkv and dense layers. # It also always treats the tensors as tensor-parallel and uses a different rotary embedding format. @@ -63,7 +67,7 @@ def _init_attention_megatron( meta, dense_tensor_ := tensor.new_empty( config.mixer.head_size * config.mixer.heads, - config.hidden_size, + hidden_size, ), generator, ) @@ -75,7 +79,7 @@ def _init_attention_megatron( config.mixer.head_groups, heads_per_group + 2, config.mixer.head_size, - config.hidden_size, + hidden_size, ), generator, ) @@ -141,19 +145,23 @@ def _init_moe_router_megatron( def _init_moe_mlp_megatron( - config: BlockConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: DecoderBlockConfig, + meta: "ParameterMeta", + tensor: "torch.Tensor", + distributed: "Distributed", + hidden_size: int, ) -> "torch.Tensor": assert meta.param_init_method is not None generator = distributed.tp_init_generator if meta.is_tensor_parallel else distributed.pp_init_generator # self.param_init_method(self, tensor, generator) state = generator.get_state() weight_1 = tensor.new_empty( - config.mlp.experts * (1 + config.mlp.gated) * config.mlp.intermediate_size, config.hidden_size + config.mlp.experts * (1 + config.mlp.gated) * config.mlp.intermediate_size, hidden_size ) - weight_2 = tensor.new_empty(config.mlp.experts * config.mlp.intermediate_size, config.hidden_size) + weight_2 = tensor.new_empty(config.mlp.experts * config.mlp.intermediate_size, hidden_size) for chunk_1, chunk_2 in zip(weight_1.chunk(config.mlp.experts), weight_2.chunk(config.mlp.experts)): meta.param_init_method(meta, chunk_1, generator) - chunk_2_ = chunk_2.new_empty(config.hidden_size, config.mlp.intermediate_size) + chunk_2_ = chunk_2.new_empty(hidden_size, config.mlp.intermediate_size) meta.param_init_method(meta, chunk_2_, generator) chunk_2.copy_(chunk_2_.t()) if "layer_1.weight" in meta.tensor_name: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 8b2947837..b7d751a61 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -4,17 +4,15 @@ import torch from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef +from fast_llm.engine.base_model.base_model import BaseModel, Layer from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.block.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType -from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig @@ -37,79 +35,19 @@ def __init__( config: GPTBaseModelConfig, distributed_config: DistributedConfig, ): - self._hidden_dim = TensorDim("hidden", config.transformer.hidden_size) + self._hidden_dim = TensorDim("hidden", config.embeddings_layer.hidden_size) super().__init__(config, distributed_config) if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) - param.init_parameter = get_init_megatron(param, self._config.transformer) # Noqa + param.init_parameter = get_init_megatron( + param, self._config.decoder.block, config.embeddings_layer.hidden_size + ) # Noqa # `self._reference_models` is not populated at this point, so we pass a mutable dict. self._preprocessors: list[Preprocessor] = self._config.get_preprocessors(distributed_config) - def _get_output_layers(self) -> list[Layer]: - layers = [] - for i in range(self._config.output_layer.prediction_heads): - if i > 0: - layers.append( - self._get_block( - # TODO MTP: which index? - max(self._config.transformer.num_layers + i, 1), - f"MPT head {i} block", - # The last layer only returns the transformer output. - # The previous layers return a stack of shared_hidden and transformer_output. - i < self._config.output_layer.prediction_heads - 1, - ) - ) - layers.append(self._get_head(i)) - return layers - def get_layers(self) -> list[Layer]: - return [ - self._get_embeddings(), - *[ - self._get_block( - i + 1, - f"Block {i + 1}", - # The last layer only returns the transformer output. - # The previous layers return a stack of shared_hidden and transformer_output. - self._config.output_layer.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, - ) - for i in range(self._config.transformer.num_layers) - ], - *self._get_output_layers(), - ] - - def _get_block( - self, - block_index: int, - name: str, - return_input: bool = False, - ): - return self._config.transformer.get_layer( - self._distributed_config, - hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, - return_input=return_input, - ) - - def _get_embeddings(self): - return self._config.embeddings_layer.get_layer( - self._distributed_config, - hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, - ) - - def _get_head(self, prediction_distance): - return self._config.output_layer.get_layer( - self._distributed_config, - self._config.embeddings_layer, - hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, - prediction_distance=prediction_distance, - ) + return self._config.get_blocks(self._distributed_config) def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType @@ -339,10 +277,6 @@ def preprocess( def embedding(self) -> LanguageModelEmbedding: return self.layers[0] - @property - def transformer_layers(self) -> list[Block]: - return self.layers[1:-1] - @property def model_head(self) -> LanguageModelHead: return self.layers[self.model_head_indices[0]] @@ -369,54 +303,6 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: else: return {} - @property - def loss_defs(self) -> list[LossDef]: - loss_defs = [] - if ( - isinstance(self._config.transformer.mlp, MoEMLPConfig) - and self._config.transformer.mlp.experts > 1 - and self._config.transformer.mlp.routing == RoutingType.topk - ): - loss_defs.append( - LossDef( - name=MLPLossNames.load_balancing_loss, - formatted_name="load balancing loss", - count=self._config.transformer.num_layers, - ) - ) - if self._config.transformer.mlp.z_loss_coefficient: - loss_defs.append( - LossDef( - name=MLPLossNames.router_z_loss, - formatted_name="router z loss", - count=self._config.transformer.num_layers, - ) - ) - if self._config.output_layer.logit_z_loss: - LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) - - if self._config.output_layer.enable_dpo: - loss_defs.append(LossDef(name=LanguageModelLossNames.dpo_loss, formatted_name="dpo loss", count=1)) - - if self._config.output_layer.distillation_model is not None: - loss_defs.append( - LossDef(name=LanguageModelLossNames.distillation_loss, formatted_name="distillation loss", count=1) - ) - if self._config.output_layer.language_model_loss_factor > 0.0: - loss_defs.append( - LossDef(name=LanguageModelLossNames.distil_lm_loss, formatted_name="distillation lm loss", count=1) - ) - - for i in range(self._config.output_layer.prediction_heads): - loss_defs.append( - LossDef( - name=LanguageModelLossNames.multi_token_prediction_loss(i), - formatted_name=f"language model loss {i}", - count=1, - ) - ) - return loss_defs - class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py deleted file mode 100644 index 7152666cf..000000000 --- a/fast_llm/models/ssm/config.py +++ /dev/null @@ -1,189 +0,0 @@ -import logging -import typing - -from fast_llm.config import Field, FieldHint, FieldUpdate, config_class -from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.checkpoint.config import CheckpointHandler -from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig -from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig -from fast_llm.models.gpt.config import ( - GPTBaseModelConfig, - GPTBatchConfig, - GPTHuggingfaceCheckpointFormat, - PretrainedGPTModelConfig, -) -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM - from fast_llm.models.ssm.model import HybridSSMInferenceRunner, HybridSSMModel - from fast_llm.models.ssm.trainer import HybridSSMTrainer - -logger = logging.getLogger(__name__) - - -@config_class() -class HybridSSMBaseModelConfig(GPTBaseModelConfig): - _abstract = False - - ssm: SSMConfig = Field( - desc="Configuration for the transformer architecture.", - hint=FieldHint.architecture, - ) - hybrid_block_layout: list[SSMBlockType] | None = Field( - default=None, - desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", - hint=FieldHint.core, - ) - default_mtp_type: SSMBlockType | None = Field( - default=None, - desc="Multi-token prediction mixer to use in the model. If None, will use the last block type in `hybrid_block_layout`.", - hint=FieldHint.optional, - ) - # TODO: Support combination of different SSM block types. - ssm_block_type: SSMBlockType | None = Field(init=False) - - def _validate(self): - if self.hybrid_block_layout is None: - with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers - - if len(self.hybrid_block_layout) != self.transformer.num_layers: - message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError(message) - num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) - logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") - self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - - super()._validate() - ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} - # TODO: Support combination of different SSM block types. - Assert.leq(len(ssm_block_types), 1) - self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None - - -class LLambaHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "llamba" - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import LLambaHuggingfaceCheckpointHandler - - return LLambaHuggingfaceCheckpointHandler - - -class AprielSSMHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "apriel_ssm" - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import AprielSSMHuggingfaceCheckpointHandler - - return AprielSSMHuggingfaceCheckpointHandler - - -class AprielSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "apriel_ssm_hybrid" - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import AprielSSMHHybridHuggingfaceCheckpointHandler - - return AprielSSMHHybridHuggingfaceCheckpointHandler - - -class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import AprielThinkerSSMHHybridHuggingfaceCheckpointHandler - - return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler - - -@config_class(dynamic_type={FastLLMModelConfig: "hybrid_ssm"}) -class HybridSSMModelConfig(FastLLMModelConfig): - _abstract = False - model_name: typing.ClassVar[str] = "hybrid_ssm" - base_model: HybridSSMBaseModelConfig = FieldUpdate() - checkpoint_formats = FastLLMModelConfig.checkpoint_formats + ( - LLambaHuggingfaceCheckpointFormat, - AprielSSMHuggingfaceCheckpointFormat, - AprielSSMHHybridHuggingfaceCheckpointFormat, - AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, - ) - - @classmethod - def get_model_class(cls) -> type["HybridSSMModel"]: - from fast_llm.models.ssm.model import HybridSSMModel - - return HybridSSMModel - - @classmethod - def get_inference_runner_class(cls) -> type["HybridSSMInferenceRunner"]: - from fast_llm.models.ssm.model import HybridSSMInferenceRunner - - logger.warning( - "HybridSSMInferenceRunner only supports training-style forward pass. Use generate with cache disabled." - ) - - return HybridSSMInferenceRunner - - @classmethod - def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: - from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM - - return HuggingfaceHybridSSMModelForCausalLM - - def _validate(self): - logger.warning( - "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." - ) - super()._validate() - - -@config_class() -class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): - _abstract = False - model: HybridSSMModelConfig = FieldUpdate() - - -@config_class(dynamic_type={RunnableConfig: "train_hybrid_ssm", TrainerConfig: "hybrid_ssm"}) -class HybridSSMTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate() - batch: GPTBatchConfig = FieldUpdate() - reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() - - @classmethod - def get_trainer_class(cls) -> type["HybridSSMTrainer"]: - from fast_llm.models.ssm.trainer import HybridSSMTrainer - - return HybridSSMTrainer - - def _validate(self) -> None: - super()._validate() - if (name := self.model.base_model.output_layer.distillation_model) is None: - Assert.empty(self.reference_models) - else: - Assert.eq(self.reference_models.keys(), {name}) - if self.model.base_model.embeddings_layer.position_embeddings.enabled: - Assert.geq(self.model.base_model.embeddings_layer.num_position_embeddings, self.batch.sequence_length) - # if self.model.base_model.distillation_model is not None: - # # TODO: Support loss masking for distillation? - # assert not self.batch.use_loss_masking_spans - for reference_model in self.reference_models.values(): - Assert.none(reference_model.model.base_model.output_layer.distillation_model) - # TODO: Support more LM head features. - Assert.none(reference_model.model.base_model.output_layer.cross_entropy_splits) - Assert.eq( - reference_model.model.base_model.embeddings_layer.vocab_parallel, - self.model.base_model.embeddings_layer.vocab_parallel, - ) - Assert.geq( - reference_model.model.base_model.output_layer.prediction_heads, - self.model.base_model.output_layer.prediction_heads, - ) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py deleted file mode 100644 index 999974ea3..000000000 --- a/fast_llm/models/ssm/conversion.py +++ /dev/null @@ -1,774 +0,0 @@ -import json -import os -import pathlib -import typing - -from transformers import PretrainedConfig - -from fast_llm.config import MISSING -from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import ( - ConstantExportParamConverter, - ConstantImportParamConverter, - IgnoreImportParamConverter, - IgnoreImportWeightConverter, - MappedConfigParamConverter, - ParamConverter, - RenameParamConverter, - SplitWeightConverter, - WeightConverter, -) -from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.normalization.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType -from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter -from fast_llm.models.ssm.config import ( - AprielSSMHHybridHuggingfaceCheckpointFormat, - AprielSSMHuggingfaceCheckpointFormat, - AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, - HybridSSMModelConfig, - LLambaHuggingfaceCheckpointFormat, -) -from fast_llm.models.ssm.external.apriel_15b_hybrid import ( - configuration_ssm_hybrid_apriel15b, - modeling_ssm_hybrid_apriel15b, -) -from fast_llm.models.ssm.external.apriel_hybrid import configuration_ssm_hybrid_apriel, modeling_ssm_hybrid_apriel -from fast_llm.models.ssm.model import HybridSSMModel -from fast_llm.utils import Assert - - -class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - _default_block_type: str = SSMBlockType.mamba2_discrete.value - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = MappedConfigParamConverter( - fast_llm_names=(("hybrid_block_layout",),), - export_names=(("hybrid_block_layout",),), - fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, - export_value=lambda x: [x_.value for x_ in x], - ) - return super()._create_config_converters() + [block_converter] - - -class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - RenameParamConverter( - fast_llm_names=(("ssm", "state_size"),), - export_names=( - ( - "ssm_cfg", - "d_state", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "n_v_heads"),), - export_names=( - ( - "ssm_cfg", - "n_v_heads", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "n_qk_heads"),), - export_names=( - ( - "ssm_cfg", - "n_qk_heads", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "expansion_factor"),), - export_names=( - ( - "ssm_cfg", - "expand", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "chunk_size"),), - export_names=( - ( - "ssm_cfg", - "chunk_size", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "add_bias_linear"),), - export_names=( - ( - "ssm_cfg", - "bias", - ), - ), - ), - MappedConfigParamConverter( - fast_llm_names=(("ssm", "activation_type"),), - export_names=( - ( - "ssm_cfg", - "activation", - ), - ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - # ================================================ - # Mamba2 specific parameters: they dont exist in old checkpoints exported for discrete Mamba2, hence need backward compatibility - RenameParamConverter( - fast_llm_names=(("ssm", "dt_rank"),), - export_names=( - ( - "ssm_cfg", - "dt_rank", - ), - ), - ignore_missing=True, - default_value=None, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "dt_min"),), - export_names=( - ( - "ssm_cfg", - "dt_min", - ), - ), - ignore_missing=True, - default_value=0.001, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "dt_max"),), - export_names=( - ( - "ssm_cfg", - "dt_max", - ), - ), - ignore_missing=True, - default_value=0.1, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "dt_init_floor"),), - export_names=( - ( - "ssm_cfg", - "dt_init_floor", - ), - ), - ignore_missing=True, - default_value=1e-4, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "dt_scale"),), - export_names=( - ( - "ssm_cfg", - "dt_scale", - ), - ), - ignore_missing=True, - default_value=1.0, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "d_xb"),), - export_names=( - ( - "ssm_cfg", - "d_xb", - ), - ), - ignore_missing=True, - default_value=None, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "conv_kernel_dimension"),), - export_names=( - ( - "ssm_cfg", - "d_conv", - ), - ), - ignore_missing=True, - default_value=4, - ), - ] - - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() or [] - - num_layers = self._model.config.base_model.transformer.num_layers - ssm_bias: bool = self._model.config.base_model.transformer.add_linear_biases - - for i in range(num_layers): - # SSM - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias - ) - converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_weight", - f"model.layers.{i}.mixer.conv1d.weight", - self._model.config.base_model, - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_bias", - f"model.layers.{i}.mixer.conv1d.bias", - self._model.config.base_model, - ) - ) - # ================================================ - # Mamba2 specific parameters - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.dt_in_proj", f"model.layers.{i}.mixer.dt_in_proj", ssm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False - ) - # bias is treated separately in Mamba2 and must always exist (https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.dt_proj_bias", - f"model.layers.{i}.mixer.dt_proj.bias", - self._model.config.base_model, - ) - ) - - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.A_log", f"model.layers.{i}.mixer.A_log", self._model.config.base_model - ) - ) - - return converters - - def _get_weight_and_bias_converters( - self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, - ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) - ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) - ) - return converters - - -class LLambaHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat - _hf_prefix: str = "backbone" - architecture: typing.ClassVar[str] = "LlambaForCausalLM" - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - """ - Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json - """ - return super()._create_config_converters() + [ - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) - ), - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("n_layer",),), - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), - fast_llm_value=RMSNormalizationConfig.dynamic_type_name, - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=( - ( - "mlp_cfg", - "act_fn", - ), - ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "add_linear_biases"),), - export_names=( - ( - "mlp_cfg", - "bias", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=( - ( - "mlp_cfg", - "intermediate_size", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("d_model",),), - ), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_embeddings",),), - ), - ] - - def _create_weight_converters(self) -> list[WeightConverter]: - # not using super() because LLamba model is called backbone in the checkpoints - converters = [] - num_layers = self._model.config.base_model.transformer.num_layers - norm_bias: bool = False - ssm_bias: bool = self._model.config.base_model.transformer.add_linear_biases - - # Embedding and output - if self._model.config.base_model.tie_word_embeddings: - converters.append( - WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") - ) - converters.append(IgnoreImportWeightConverter((), f"{self._hf_prefix}.lm_head.weight")) - else: - converters.append( - WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") - ) - converters.append( - WeightConverter(f"layers.{num_layers + 1}.output_weights", f"{self._hf_prefix}.lm_head.weight") - ) - - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", f"{self._hf_prefix}.final_layernorm", norm_bias - ) - - for i in range(num_layers): - # SSM - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"{self._hf_prefix}.layers.{i}.mixer.in_proj", ssm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"{self._hf_prefix}.layers.{i}.mixer.out_proj", ssm_bias - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.D", f"{self._hf_prefix}.layers.{i}.mixer.D", self._model.config.base_model - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.z_bias", - f"{self._hf_prefix}.layers.{i}.mixer.z_bias", - self._model.config.base_model, - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_weight", - f"{self._hf_prefix}.layers.{i}.mixer.conv1d.weight", - self._model.config.base_model, - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_bias", - f"{self._hf_prefix}.layers.{i}.mixer.conv1d.bias", - self._model.config.base_model, - ) - ) - - # Norm - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_1", f"{self._hf_prefix}.layers.{i}.input_layernorm", norm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_2", f"{self._hf_prefix}.layers.{i}.post_attention_layernorm", norm_bias - ) - - # MLP - converters += self._get_mlp_converters(f"layers.{i+1}", f"{self._hf_prefix}.layers.{i}") - - return converters - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - linear_bias, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - linear_bias, - MLPLayer2Converter, - ), - ] - - def _get_weight_and_bias_converters( - self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, - ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) - ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) - ) - return converters - - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) - - -class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): - """ - Lamba-like configs, pure SSM models. - """ - - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "AprielSSMForCausalLM" - modeling_file = modeling_ssm_hybrid_apriel15b.__file__ - configuration_file = configuration_ssm_hybrid_apriel15b.__file__ - configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = ( - configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig - ) - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "d_inner"),), - export_names=(("ssm_cfg", "d_inner"),), - ), - ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("num_hidden_layers",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("hidden_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=(("intermediate_size",),), - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), - fast_llm_value=RMSNormalizationConfig.dynamic_type_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), - ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_word_embeddings",),), - ), - ] - - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() - num_layers = self._model.config.base_model.transformer.num_layers - norm_bias: bool = False - - # Embedding and output - if self._model.config.base_model.tie_word_embeddings: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) - else: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) - - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias - ) - - for i in range(num_layers): - # Norm - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_1", f"model.layers.{i}.input_layernorm", norm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_2", f"model.layers.{i}.post_attention_layernorm", norm_bias - ) - - # MLP - converters += self._get_mlp_converters(f"layers.{i+1}", f"model.layers.{i}") - - return converters - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - linear_bias, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - linear_bias, - MLPLayer2Converter, - ), - ] - - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) - - -class AprielSSMHHybridHuggingfaceCheckpointHandler( - CustomModelingExportMixin, - HybridModelCheckpointHandler, # handles the block structure parameter - CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers - CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers -): - """ - Lamba-like configs, models that interleave LLama like layers with LLamba-like SSM layers. - """ - - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHHybridHuggingfaceCheckpointFormat - _default_block_type: str = SSMBlockType.mamba2_discrete.value - architecture: typing.ClassVar[str] = "AprielSSMHybridForCausalLM" - modeling_file = modeling_ssm_hybrid_apriel.__file__ - configuration_file = configuration_ssm_hybrid_apriel.__file__ - configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = modeling_ssm_hybrid_apriel.AprielSSMHybridConfig - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("auto_map",),), - export_value={ - "AutoConfig": "configuration_ssm_hybrid_apriel.AprielSSMHybridConfig", - "AutoModel": "modeling_ssm_hybrid_apriel.AprielSSMHybridModel", - "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel.AprielSSMHybridForCausalLM", - }, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "d_inner"),), - export_names=(("ssm_cfg", "d_inner"),), - ), - ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), - ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - linear_bias, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - linear_bias, - MLPLayer2Converter, - ), - ] - - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) - - -class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( - CustomModelingExportMixin, - HybridModelCheckpointHandler, # handles the block structure parameter - CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers - CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers -): - """ - Lamba-like configs, models that interleave LLama like layers with LLamba-like SSM layers. - """ - - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - format: typing.ClassVar[type[CheckpointFormat]] = AprielThinkerSSMHHybridHuggingfaceCheckpointFormat - _default_block_type: str = SSMBlockType.mamba2_discrete.value - _hf_prefix: str = "model" - architecture: typing.ClassVar[str] = "AprielThinkerSSMHybridForCausalLM" - modeling_file = modeling_ssm_hybrid_apriel15b.__file__ - configuration_file = configuration_ssm_hybrid_apriel15b.__file__ - configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = ( - configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig - ) - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("auto_map",),), - export_value={ - "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", - "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", - "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", - }, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "d_inner"),), - export_names=(("ssm_cfg", "d_inner"),), - ), - IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), - ] - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - linear_bias, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - linear_bias, - MLPLayer2Converter, - ), - ] - - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) diff --git a/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py deleted file mode 100644 index 1d230bb67..000000000 --- a/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py +++ /dev/null @@ -1,448 +0,0 @@ -import math -from typing import Optional - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import is_torch_available, logging - -logger = logging.get_logger(__name__) - -if is_torch_available(): - import torch - - -def _compute_default_rope_parameters( - config: Optional[PretrainedConfig] = None, - device: Optional["torch.device"] = None, - seq_len: Optional[int] = None, - **rope_kwargs, -) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies according to the original RoPE implementation - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - if config is not None and len(rope_kwargs) > 0: - raise ValueError( - "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " - f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" - ) - if len(rope_kwargs) > 0: - base = rope_kwargs["base"] - dim = rope_kwargs["dim"] - elif config is not None: - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) - return inv_freq, attention_factor - - -def _compute_yarn_parameters( - config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs -) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies with NTK scaling. Please refer to the - [original paper](https://arxiv.org/abs/2309.00071) - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - rope_kwargs (`Dict`, *optional*): - BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin. - """ - # No need to keep BC with yarn, unreleased when this new pattern was created. - if len(rope_kwargs) > 0: - raise ValueError( - f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" - ) - - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - - # Apriel: Use original max_position_embeddings instead of max_position_embeddings - max_position_embeddings = config.rope_scaling.get( - "original_max_position_embeddings", config.max_position_embeddings - ) - factor = config.rope_scaling["factor"] - - # Sets the attention factor as suggested in the paper - attention_factor = config.rope_scaling.get("attention_factor") - if attention_factor is None: - attention_factor = 0.1 * math.log(factor) + 1.0 - - # Optional config options - # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) - beta_fast = config.rope_scaling.get("beta_fast") or 32 - beta_slow = config.rope_scaling.get("beta_slow") or 1 - - # Compute the inverse frequencies - def find_correction_dim(num_rotations, dim, base, max_position_embeddings): - """Inverse dimension formula to find the dimension based on the number of rotations""" - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): - """Find dimension range bounds based on rotations""" - low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_factor(min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs - # to expand the possible context length. In other words, interpolation = apply scaling factor. - pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (factor * pos_freqs) - - low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) - - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) - + inv_freq_extrapolation * inv_freq_extrapolation_factor - ) - - return inv_freq, attention_factor - - -def _check_received_keys( - rope_type: str, - received_keys: set, - required_keys: set, - optional_keys: Optional[set] = None, - ignore_keys: Optional[set] = None, -): - """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" - # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present - if "type" in received_keys: - received_keys -= {"type"} - required_keys.add("rope_type") - - # Some models need to store model-specific keys, and we don't want to throw warning at them - if ignore_keys is not None: - received_keys -= ignore_keys - - missing_keys = required_keys - received_keys - if missing_keys: - raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") - - if optional_keys is not None: - unused_keys = received_keys - required_keys - optional_keys - else: - unused_keys = received_keys - required_keys - if unused_keys: - logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") - - -def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) - - -def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type", "factor", "original_max_position_embeddings"} - optional_keys = {"attention_factor", "beta_fast", "beta_slow"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) - - factor = rope_scaling["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") - - attention_factor = rope_scaling.get("attention_factor") - if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): - logger.warning( - f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - beta_fast = rope_scaling.get("beta_fast") - if beta_fast is not None and not isinstance(beta_fast, float): - logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") - beta_slow = rope_scaling.get("beta_slow") - if beta_slow is not None and not isinstance(beta_slow, float): - logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - - if (beta_fast or 32) < (beta_slow or 1): - logger.warning( - f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " - f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" - ) - - -# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters -# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE -# parameterizations, as long as the callable has the same signature. -ROPE_INIT_FUNCTIONS = { - "default": _compute_default_rope_parameters, - "yarn": _compute_yarn_parameters, -} - -# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. -ROPE_VALIDATION_FUNCTIONS = { - "default": _validate_default_rope_parameters, - "yarn": _validate_yarn_parameters, -} - - -def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): - """ - Validate the RoPE config arguments, given a `PretrainedConfig` object - """ - rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` - if rope_scaling is None: - return - - # BC: "rope_type" was originally "type" - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) - validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) - if validation_fn is not None: - validation_fn(config, ignore_keys=ignore_keys) - else: - logger.warning( - f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" - ) - - -class AprielSSMHybridConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`AprielModel`]. It is used to instantiate an Apriel - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Apriel-5B-Base. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the Apriel model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`AprielModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Apriel-5B-Base supports up to 16384 tokens. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - pretraining_tp (`int`, *optional*, defaults to 1): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to - understand more about it. This value is necessary to ensure exact reproducibility of the pretraining - results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'yarn', 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - attention_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - mlp_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. - head_dim (`int`, *optional*): - The attention head dimension. If None, it will default to hidden_size // num_attention_heads - ```python - >>> from transformers import AprielModel, AprielConfig - >>> # Initializing an Apriel Apriel-5B-Base style configuration - >>> configuration = AprielConfig() - >>> # Initializing a model from the Apriel-5B-Base style configuration - >>> model = AprielModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "apriel_ssm_hybrid" - keys_to_ignore_at_inference = ["past_key_values"] - # Default tensor parallel plan for base model `AprielModel` - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } - base_model_pp_plan = { - "embed_tokens": (["input_ids"], ["inputs_embeds"]), - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), - "norm": (["hidden_states"], ["hidden_states"]), - } - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_theta=10000.0, - rope_scaling=None, - attention_bias=False, - attention_dropout=0.0, - mlp_bias=False, - head_dim=None, - hybrid_block_layout=["m2d"], - ssm_cfg=None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.mlp_bias = mlp_bias - self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads - self.hybrid_block_layout = hybrid_block_layout - if len(hybrid_block_layout) == 1: - self.hybrid_block_layout = [hybrid_block_layout[0]] * self.num_hidden_layers - assert len(self.hybrid_block_layout) == self.num_hidden_layers - # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, copy it it to 'rope_type'. - if self.rope_scaling is not None and "type" in self.rope_scaling: - self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self) - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - ssm_defaults = { - "d_state": 64, - "n_v_heads": 24, - "n_qk_heads": 24, - "expand": 1, - "chunk_size": 128, - "activation": "identity", - "bias": False, - "d_conv": 4, - "d_inner": 24 * self.head_dim, # num_heads * head_dim - } - self.ssm_cfg = ssm_cfg or ssm_defaults - for k, v in ssm_defaults.items(): - if k not in self.ssm_cfg: - self.ssm_cfg[k] = v diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py deleted file mode 100644 index 771f81a7d..000000000 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ /dev/null @@ -1,1576 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Callable, Optional, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from einops import rearrange, repeat -from mamba_ssm.ops.triton.selective_state_update import selective_state_update -from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined -from torch import nn -from transformers import GenerationMixin -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache, StaticCache -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from transformers.processing_utils import Unpack -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from transformers.utils.generic import ModelOutput - -from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import ( - ROPE_INIT_FUNCTIONS, - AprielSSMHybridConfig, -) - -logger = logging.get_logger(__name__) - -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) - - -def apply_mask_to_padding_states(hidden_states, attention_mask): - """ - Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 - """ - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - dtype = hidden_states.dtype - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - - return hidden_states - - -class HybridMambaAttentionStaticCache(Cache): - def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): - super().__init__() # config, batch_size, max_length, device, dtype) - self.dtype = dtype - self.hybrid_override_pattern = config.hybrid_block_layout - self.has_previous_state = False # only used by mamba - intermediate_size = config.ssm_cfg["d_inner"] - ssm_state_size = config.ssm_cfg["d_state"] - conv_kernel_size = config.ssm_cfg["d_conv"] - self.n_qk_heads = config.ssm_cfg["n_qk_heads"] - assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" - self.head_d = intermediate_size // self.n_qk_heads - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - self.batch_size = batch_size - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) - self.max_cache_len = config.max_position_embeddings if max_length is None else max_length - - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) - - for i in range(config.num_hidden_layers): - if self.hybrid_override_pattern[i] == "m2d": - # Mamba layer - new_layer_conv_state = torch.zeros( - batch_size, - conv_kernel_size, - intermediate_size + 2 * self.n_qk_heads * ssm_state_size, - device=device, - dtype=dtype, - ).transpose(1, 2) - - new_layer_ssm_state = torch.zeros( - batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype - ) - new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) - new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) - else: - # Attention or MLP layer - new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) - new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) - new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) - self.transformer_layers.append(i) - - # if not is_torchdynamo_compiling(): - # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) - # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) - # new_layer_key_cache = getattr(self, f"key_cache_{i}") - # new_layer_value_cache = getattr(self, f"value_cache_{i}") - # torch._dynamo.mark_static_address(new_layer_key_cache) - # torch._dynamo.mark_static_address(new_layer_value_cache) - # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) - # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) - # torch._dynamo.mark_static_address(new_layer_conv_state) - # torch._dynamo.mark_static_address(new_layer_ssm_state) - # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") - # new_layer_conv_state = getattr(self, f"conv_states_{i}") - - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - self.conv_states.append(new_layer_conv_state) - self.ssm_states.append(new_layer_ssm_state) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input - to know how where to write in the cache. - - Return: - A tuple containing the updated key and value states. - """ - - cache_position = cache_kwargs.get("cache_position") - - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place - # operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - return k_out, v_out - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: Optional[int] = None) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx is None: - if len(self.transformer_layers) > 0: - layer_idx = self.transformer_layers[0] - else: - return 0 - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len - - # Copied from modeling_mamba2.py - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) - return self.ssm_states[layer_idx] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - -# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py -class HybridMambaAttentionDynamicCache(DynamicCache): - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): - super().__init__() - self.dtype = dtype - self.hybrid_override_pattern = config.hybrid_block_layout - self.has_previous_state = False # only used by mamba - intermediate_size = config.ssm_cfg["d_inner"] - ssm_state_size = config.ssm_cfg["d_state"] - conv_kernel_size = config.ssm_cfg["d_conv"] - self.n_qk_heads = config.ssm_cfg["n_qk_heads"] - assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" - self.head_d = intermediate_size // self.n_qk_heads - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.hybrid_override_pattern[i] == "m2d": - # Mamba layer - self.conv_states += [ - torch.zeros( - batch_size, - conv_kernel_size, - intermediate_size + 2 * self.n_qk_heads * ssm_state_size, - device=device, - dtype=dtype, - ).transpose(1, 2) - ] - self.ssm_states += [ - torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) - ] - else: - # Attention or MLP layer - self.conv_states += [torch.tensor([[]] * batch_size, device=device)] - self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - # Copied from modeling_mamba2.py - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) - return self.ssm_states[layer_idx] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - -@dataclass -class AprielHybridCausalOutput(ModelOutput): - """Custom output class for MambaLMHeadModel.""" - - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None - last_hidden_state: Optional[torch.FloatTensor] = None - attention_weights: Optional[torch.FloatTensor] = None - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None - - -class AprielRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): - """ - AprielRMSNorm is equivalent to T5LayerNorm - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -ALL_LAYERNORM_LAYERS.append(AprielRMSNorm) - - -class AprielMLP(nn.Module): - def __init__(self, config, device=None, dtype=None, **kwargs): - super().__init__(**kwargs) - factory_kwargs = {"device": device, "dtype": dtype} - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class AprielRotaryEmbedding(nn.Module): - def __init__(self, config: AprielSSMHybridConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -class AprielAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -def segsum(x): - """More stable segment sum calculation.""" - # [1, 2, 3] - T = x.size(-1) - x = repeat(x, "... d -> ... d e", e=T) - # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) - x = x.masked_fill(~mask, 0) - # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] - x_segsum = torch.cumsum(x, dim=-2) - # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) - x_segsum = x_segsum.masked_fill(~mask, -torch.inf) - return x_segsum - - -def materialize_mixer(A_log, B, C, D): - """ - Since the transfer matrix will be equated to the attention matrix, - we need to support the form: torch.matmul(attn_weights, value_states). - Thus, y = torch.matmul(T, X) - Arguments: - A_log: (batch, length, n_heads) - B: (batch, length, n_heads, d_state) - C: (batch, length, n_heads, d_state) - Return: - T: (batch, n_heads, length, length) - """ - batch_size, length, n_heads, d_state = B.shape - assert A_log.shape == (batch_size, length, n_heads) - assert B.shape == C.shape == (batch_size, length, n_heads, d_state) - - # Compute: - A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") - powers = torch.exp(segsum(A_log)) - T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) - - # Add D: - if D is not None: - T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) - - T = rearrange(T, "b h z l -> b h l z") - return T - - -class DiscreteMamba2(nn.Module): - def __init__( - self, - d_model, - d_state=64, - n_qk_heads=32, - n_v_heads=32, - d_conv=4, - expand=1, - activation="identity", - bias=False, - conv_bias=True, - chunk_size=128, - layer_idx=None, - device=None, - dtype=None, - d_inner=None, - **kwargs, # Absorb kwarg for general module - ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. - Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" - - Other options are all experimental and should not need to be configured - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = self.expand * self.d_model if d_inner is None else d_inner - self.n_qk_heads = n_qk_heads - self.n_v_heads = n_v_heads - self.headdim = self.d_inner // self.n_v_heads - assert self.n_v_heads == self.d_inner // self.headdim - assert self.d_inner % self.headdim == 0 - assert self.n_v_heads % self.n_qk_heads == 0 - self.activation = activation - self.chunk_size = chunk_size - self.layer_idx = layer_idx - self.bias = bias - self.kwargs = kwargs - - # Projections - self.in_proj = nn.Linear( - self.d_model, - 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, - bias=bias, - **factory_kwargs, - ) - self.z_bias = ( - nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 - ) # make sure z_bias always exists - - # Convolutional layer - conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state - self.conv_bias = conv_bias - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - **factory_kwargs, - ) - - # Activation after conv - if self.activation == "identity": - self.act = nn.Identity() - elif self.activation in ["silu", "swish"]: - self.act = nn.SiLU() - else: - raise ValueError(f"Unknown activation {self.activation}") - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) - self.D._optim = {"weight_decay": 0.0} - - # out_proj - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - # In __init__, pre-allocate these tensors - self.zeros_buffer = torch.zeros((self.n_v_heads, self.headdim), device=device, dtype=dtype) - self.ones_buffer = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=device, dtype=dtype) - - @property - def d_output(self): - return self.d_model - - @property - def state_to_tensor(self): - return self.layer.state_to_tensor - - def forward( - self, - u, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - attention_mask: Optional[torch.Tensor] = None, - return_mixer_matrix=False, - **kwargs, - ): - """ - u: (B, L, D) - Returns: same shape as u - For later refference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bamba/modeling_bamba.py - """ - assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda" - cache_position = kwargs.get("cache_position", None) - batch, seqlen, dim = u.shape - u = apply_mask_to_padding_states(u, attention_mask) - ssm_state, conv_state = None, None - - use_precomputed_states = ( - past_key_value is not None - and past_key_value.has_previous_state - and seqlen == 1 - and past_key_value.conv_states[self.layer_idx].shape[0] - == past_key_value.ssm_states[self.layer_idx].shape[0] - == batch - and cache_position is not None - and cache_position[0] > 0 - ) - ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) - if use_precomputed_states: - u = u.squeeze(1) if len(u.shape) == 3 else u - out, _, _ = self.step(u, ssm_state, conv_state) - out = out.unsqueeze(1) if len(u.shape) == 2 else out - return {"hidden_states": out} - - outputs = {} - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if ssm_state is None else seqlen - - # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = F.pad(u, (0, 0, 0, padded_len - seqlen)) - - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - if ssm_state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) - - # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) - - # SSM forward - result = mamba_chunk_scan_combined( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=A_log, - dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), - B=B, - C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(ssm_state is not None), - ) - - if ssm_state is not None: - y, ssm_state_update = result - ssm_state.copy_(ssm_state_update) - else: - y = result - - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = rearrange(y + Du, "b l h p -> b l (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :] - - if return_mixer_matrix: - outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] - return outputs - - def step(self, u, ssm_state, conv_state, **kwargs): - """ - u: (B D) - state: dict of states - Returns: same shape as u - """ - - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - xBC, conv_state_new = self.convolutional_step(xBC, conv_state) - if conv_state_new is not None: - raise NotImplementedError("Should not end up here snce only support fast path.") - # conv_state.copy_(conv_state_new) # update state in place, only for slow pass - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) - B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) - C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) - - ssm_state = ssm_state.to(x.dtype) - zeros = self.zeros_buffer.to(A_log.device).to(x.dtype) # Just cast, don't allocate - ones = self.ones_buffer.to(A_log.device).to(x.dtype) - y = selective_state_update( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=repeat(A_log, "b h -> b h p", p=self.headdim), - dt_softplus=True, - A=-ones, - B=B, - C=C, - state=ssm_state, # will be updated in place - dt_bias=zeros, - D=zeros, - ) - - y = y + self.D[:, None] * x - y = rearrange(y, "b h p -> b (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - - return out, ssm_state, conv_state - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - """ - conv_state: (batch, d_conv, conv1d.weight.shape[0]) - ssm_state: (batch, n_qk_heads, headdim, d_state) - """ - assert self.layer_idx is not None - # Allocate memory if not exists - # if self.layer_idx not in inference_params.ssm_states: - # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( - # batch_size, inference_params.max_seqlen, dtype=torch.float32 - # ) - # Get states - ssm_states = inference_params.ssm_states[self.layer_idx] - conv_states = inference_params.conv_states[self.layer_idx] - if initialize_states: - ssm_states.zero_() - conv_states.zero_() - return ssm_states, conv_states - - def convolutional_forward(self, xBC, padded_len): - if causal_conv1d_fn is None or self.activation not in [ - "silu", - "swish", - "identity", - ]: - xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) - else: - xBC = causal_conv1d_fn( - xBC.transpose(1, 2), - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - activation=None if self.activation == "identity" else self.activation, - ).transpose(1, 2) - return xBC - - def convolutional_step(self, xBC, conv_state): - # Convolutional layer - conv_state = conv_state.to(xBC.dtype) - if causal_conv1d_update: - xBC = causal_conv1d_update( - xBC, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation if self.activation != "identity" else None, - ) - return xBC, None - else: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = xBC - xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv_bias: - xBC = xBC + self.conv1d.bias - xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype - - return xBC, conv_state - - -class AprielDecoderLayer(nn.Module): - def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = AprielAttention(config=config, layer_idx=layer_idx) - - self.mlp = AprielMLP(config) - self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -class AprielSSMDecoderLayer(nn.Module): - def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): - super().__init__(**kwargs) - factory_kwargs = {"device": device, "dtype": dtype} - self.hidden_size = config.hidden_size - - self.mixer = DiscreteMamba2( - d_model=config.hidden_size, - layer_idx=layer_idx, - **config.ssm_cfg, - **factory_kwargs, - ) - - self.mlp = AprielMLP(config, **factory_kwargs) - self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) - self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) - - def forward( - self, hidden_states: torch.Tensor, **kwargs - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - - outputs = {} - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - mixer_outputs = self.mixer( - hidden_states, - **kwargs, - ) - - hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - # outputs["hidden_states"] = hidden_states - outputs = (hidden_states,) - - return outputs - - # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - # """Allocate inference cache for the model.""" - # if getattr(self.mixer, "allocate_inference_cache", None) is None: - # return - # return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - - -APRIEL_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - Parameters: - config ([`AprielSSMHybridConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Apriel Model outputting raw hidden-states without any specific head on top.", - APRIEL_START_DOCSTRING, -) -class AprielSSMPreTrainedModel(PreTrainedModel): - config_class = AprielSSMHybridConfig - base_model_prefix = "model" - _no_split_modules = ["AprielDecoderLayer", "AprielSSMDecoderLayer"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_attention_backend = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - # def allocate_inference_cache(self, *args, **kwargs): - # """Allocate inference cache for the model.""" - # return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) - - -APRIEL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare Apriel Model outputting raw hidden-states without any specific head on top.", - APRIEL_START_DOCSTRING, -) -class AprielSSMHybridModel(AprielSSMPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] - Args: - config: AprielSSMHybridConfig - """ - - def __init__(self, config: AprielSSMHybridConfig, device=None, dtype=None, **kwargs): - super().__init__(config, device=device, dtype=dtype, **kwargs) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - factory_kwargs = {"device": device, "dtype": dtype} - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) - blocks = [] - logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") - for layer_idx, type in enumerate(config.hybrid_block_layout): - if type == "m2d": - blocks.append(AprielSSMDecoderLayer(config, layer_idx, **factory_kwargs)) - elif type == "t": - blocks.append(AprielDecoderLayer(config, layer_idx)) - else: - raise ValueError(f"Invalid block type: {type}") - self.layers = nn.ModuleList(blocks) - self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) - self.gradient_checkpointing = False - self.rotary_emb = AprielRotaryEmbedding(config=config) - self.has_transformer_layers = any(type == "t" for type in config.hybrid_block_layout) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # def allocate_inference_cache(self, *args, **kwargs): - # """Allocate inference cache for the model.""" - # cache = {} - # for i, layer in enumerate(self.layers): - # if isinstance(layer, AprielSSMDecoderLayer): - # cache[i] = layer.allocate_inference_cache(*args, **kwargs) - # return cache - - @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - inference_params=None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None: - # past_key_values = HybridMambaAttentionDynamicCache() - logger.warning_once( - "Hybrid Apriel requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " - "provided, so no cache will be returned." - ) - - if cache_position is None and self.has_transformer_layers: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None and self.has_transformer_layers: - position_ids = cache_position.unsqueeze(0) - - causal_mask = ( - self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions) - if self.has_transformer_layers - else None - ) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) if self.has_transformer_layers else None - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - inference_params=inference_params, - **flash_attn_kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions and isinstance(decoder_layer, AprielDecoderLayer): - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - output = BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - return output if return_dict else output.to_tuple() - - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) or isinstance( - past_key_values, HybridMambaAttentionStaticCache - ) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - - -class AprielSSMHybridForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - self.model = AprielSSMHybridModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - output_router_logits=False, - cache_position=None, - position_ids=None, - use_cache=True, - **kwargs, - ): - # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - - empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) - - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - if not empty_past_kv: - if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - else: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device - ) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if not empty_past_kv: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and empty_past_kv: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "output_router_logits": output_router_logits, - # "logits_to_keep": self.config.num_logits_to_keep, - "cache_position": cache_position, - } - ) - return model_inputs - - def forward( - self, - input_ids: torch.LongTensor = None, - position_ids=None, - return_hidden_states=False, - return_logits=True, - num_last_tokens=0, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, - **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[tuple, CausalLMOutputWithPast]: - - # past_key_values is None if prepare_inputs_for_generation is not called, which is the case when we evaluate without calling generate (non-generation tasks) - # Its generally ok if cache is nto instantiated in this case, since we do single pass per sample anyways, a warning will be triggered in the model - outputs: BaseModelOutputWithPast = self.model( - input_ids, - return_hidden_states=return_hidden_states, - position_ids=position_ids, - past_key_values=past_key_values, - **kwargs, - ) - - if outputs["last_hidden_state"] is not None and return_logits: - logits = self.lm_head(outputs["last_hidden_state"]).float() - outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] - else: - outputs["logits"] = None - - return AprielHybridCausalOutput( - loss=None, - logits=outputs["logits"], - all_hidden_states=outputs.hidden_states, - last_hidden_state=outputs.last_hidden_state, - past_key_values=outputs.past_key_values, - ) - - -__all__ = [ - "AprielSSMHybridForCausalLM", - "AprielSSMHybridModel", - "AprielSSMPreTrainedModel", -] diff --git a/fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py deleted file mode 100644 index 6943a3124..000000000 --- a/fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Apriel SSM model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import is_torch_available, logging - -logger = logging.get_logger(__name__) - -if is_torch_available(): - pass - - -class AprielSSMConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`AprielModel`]. It is used to instantiate an Apriel - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Apriel-5B-Base. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - .... - ```""" - - model_type = "apriel_ssm" - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - hidden_act="silu", - initializer_range=0.02, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - tie_word_embeddings=False, - mlp_bias=False, - rms_norm_eps=1e-5, - ssm_cfg: dict = None, - head_dim: int = 128, - **kwargs, - ): - self.vocab_size = vocab_size - # self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - # self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - # self.rope_theta = rope_theta - self.mlp_bias = mlp_bias - self.head_dim = head_dim - # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, copy it it to 'rope_type'. - # if self.rope_scaling is not None and "type" in self.rope_scaling: - # self.rope_scaling["rope_type"] = self.rope_scaling["type"] - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - self.ssm_cfg = ssm_cfg or { - "d_state": 64, - "n_v_heads": 24, - "n_qk_heads": 24, - "expand": 1, - "chunk_size": 128, - "activation": "identity", - "bias": False, - "d_inner": 24 * self.head_dim, # num_heads * head_dim - } - if self.head_dim != self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"]: - logger.warning("Head dim is not equal to d_inner // n_qk_heads.") - - -__all__ = ["AprielConfig"] diff --git a/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py deleted file mode 100644 index 09dc8259c..000000000 --- a/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py +++ /dev/null @@ -1,743 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from einops import rearrange, repeat -from mamba_ssm.ops.triton.selective_state_update import selective_state_update -from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined -from mamba_ssm.utils.generation import GenerationMixin -from torch import nn -from transformers.activations import ACT2FN -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.processing_utils import Unpack -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from transformers.utils.generic import ModelOutput - -from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig - -logger = logging.get_logger(__name__) - - -@dataclass -class CustomMambaCausalLMOutput(ModelOutput): - """Custom output class for MambaLMHeadModel.""" - - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None - last_hidden_state: Optional[torch.FloatTensor] = None - - -class AprielRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): - """ - AprielRMSNorm is equivalent to T5LayerNorm - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -ALL_LAYERNORM_LAYERS.append(AprielRMSNorm) - - -class AprielMLP(nn.Module): - def __init__(self, config, device=None, dtype=None, **kwargs): - super().__init__(**kwargs) - factory_kwargs = {"device": device, "dtype": dtype} - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -def segsum(x): - """More stable segment sum calculation.""" - # [1, 2, 3] - T = x.size(-1) - x = repeat(x, "... d -> ... d e", e=T) - # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) - x = x.masked_fill(~mask, 0) - # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] - x_segsum = torch.cumsum(x, dim=-2) - # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) - x_segsum = x_segsum.masked_fill(~mask, -torch.inf) - return x_segsum - - -def materialize_mixer(A_log, B, C, D): - """ - Since the transfer matrix will be equated to the attention matrix, - we need to support the form: torch.matmul(attn_weights, value_states). - Thus, y = torch.matmul(T, X) - Arguments: - A_log: (batch, length, n_heads) - B: (batch, length, n_heads, d_state) - C: (batch, length, n_heads, d_state) - Return: - T: (batch, n_heads, length, length) - """ - batch_size, length, n_heads, d_state = B.shape - assert A_log.shape == (batch_size, length, n_heads) - assert B.shape == C.shape == (batch_size, length, n_heads, d_state) - - # Compute: - A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") - powers = torch.exp(segsum(A_log)) - T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) - - # Add D: - if D is not None: - T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) - - T = rearrange(T, "b h z l -> b h l z") - return T - - -class DiscreteMamba2(nn.Module): - def __init__( - self, - d_model, - d_state=64, - n_qk_heads=32, - n_v_heads=32, - d_conv=4, - expand=1, - activation="identity", - bias=False, - conv_bias=True, - chunk_size=128, - layer_idx=None, - device=None, - dtype=None, - d_inner=None, - **kwargs, # Absorb kwarg for general module - ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. - Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" - - Other options are all experimental and should not need to be configured - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = self.expand * self.d_model if d_inner is None else d_inner - self.n_qk_heads = n_qk_heads - self.n_v_heads = n_v_heads - self.headdim = self.d_inner // self.n_v_heads - assert self.n_v_heads == self.d_inner // self.headdim - assert self.d_inner % self.headdim == 0 - assert self.n_v_heads % self.n_qk_heads == 0 - self.activation = activation - self.chunk_size = chunk_size - self.layer_idx = layer_idx - self.bias = bias - self.kwargs = kwargs - - # Projections - self.in_proj = nn.Linear( - self.d_model, - 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, - bias=bias, - **factory_kwargs, - ) - self.z_bias = ( - nn.Parameter(torch.zeros(self.d_inner, **factory_kwargs)) if not bias else 0 - ) # make sure z_bias always exists - - # Convolutional layer - conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state - self.conv_bias = conv_bias - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - **factory_kwargs, - ) - - # Activation after conv - if self.activation == "identity": - self.act = nn.Identity() - elif self.activation in ["silu", "swish"]: - self.act = nn.SiLU() - else: - raise ValueError(f"Unknown activation {self.activation}") - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.n_v_heads, **factory_kwargs)) - self.D._optim = {"weight_decay": 0.0} - - # out_proj - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - @property - def d_output(self): - return self.d_model - - @property - def state_to_tensor(self): - return self.layer.state_to_tensor - - def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs): - """ - u: (B, L, D) - Returns: same shape as u - """ - outputs = {} - # assert state is None - batch, seqlen, dim = u.shape - - state = None - if inference_params is not None: - state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # States are updated inplace - u = u.squeeze(1) if len(u.shape) == 3 else u - out, _ = self.step(u, state) - out = out.unsqueeze(1) if len(u.shape) == 2 else out - return {"hidden_states": out} - - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen - - # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = F.pad(u, (0, 0, 0, padded_len - seqlen)) - - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - if state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) - - # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) - - # SSM forward - result = mamba_chunk_scan_combined( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=A_log, - dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), - B=B, - C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), - ) - - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) - else: - y = result - - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = rearrange(y + Du, "b l h p -> b l (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :] - - if return_mixer_matrix: - outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] - return outputs - - def step(self, u, state, **kwargs): - """ - u: (B D) - state: dict of states - Returns: same shape as u - """ - - # Project input - xBCzA_log = self.in_proj(u.squeeze(1)) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - xBC, conv_state = self.convolutional_step(xBC, state["conv"]) - state["conv"].copy_(conv_state) # update state in place - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) - B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) - C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) - - state["ssm"] = state["ssm"].to(x.dtype) - zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) - ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) - y = selective_state_update( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=repeat(A_log, "b h -> b h p", p=self.headdim), - dt_softplus=True, - A=-ones, - B=B, - C=C, - state=state["ssm"], # will be updated in place - dt_bias=zeros, - D=zeros, - ) - - y = y + self.D[:, None] * x - y = rearrange(y, "b h p -> b (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - - return out, state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - device = self.in_proj.weight.device - # conv_state: - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, - self.d_conv, - self.conv1d.weight.shape[0], - device=device, - dtype=conv_dtype, - ).transpose(1, 2) - # ssm_state: - ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype - ssm_state = torch.zeros( - batch_size, - self.n_v_heads, - self.headdim, - self.d_state, - device=device, - dtype=ssm_dtype, - ) - return {"conv": conv_state, "ssm": ssm_state} - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - """ - conv_state: (batch, d_conv, conv1d.weight.shape[0]) - ssm_state: (batch, n_qk_heads, headdim, d_state) - """ - assert self.layer_idx is not None - # Allocate memory if not exists - if self.layer_idx not in inference_params.key_value_memory_dict: - inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( - batch_size, inference_params.max_seqlen, dtype=torch.float32 - ) - # Get states - states = inference_params.key_value_memory_dict[self.layer_idx] - if initialize_states: - states["conv"].zero_() - states["ssm"].zero_() - return states - - def convolutional_forward(self, xBC, padded_len): - if causal_conv1d_fn is None or self.activation not in [ - "silu", - "swish", - "identity", - ]: - xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) - else: - xBC = causal_conv1d_fn( - xBC.transpose(1, 2), - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - activation=None if self.activation == "identity" else self.activation, - ).transpose(1, 2) - return xBC - - def convolutional_step(self, xBC, conv_state): - # Convolutional layer - conv_state = conv_state.to(xBC.dtype) - if causal_conv1d_update: - xBC = causal_conv1d_update( - xBC, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation if self.activation != "identity" else None, - ) - else: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = xBC - xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv_bias: - xBC = xBC + self.conv1d.bias - xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype - - return xBC, conv_state - - -class AprielDecoderLayer(nn.Module): - def __init__(self, config: AprielSSMConfig, layer_idx: int, device=None, dtype=None, **kwargs): - super().__init__(**kwargs) - factory_kwargs = {"device": device, "dtype": dtype} - self.hidden_size = config.hidden_size - - self.mixer = DiscreteMamba2( - d_model=config.hidden_size, - layer_idx=layer_idx, - **config.ssm_cfg, - **factory_kwargs, - ) - - self.mlp = AprielMLP(config, **factory_kwargs) - self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) - self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) - - def forward( - self, hidden_states: torch.Tensor, inference_params=None, **kwargs - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - - outputs = {} - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - mixer_outputs = self.mixer( - hidden_states, - inference_params=inference_params, - ) - - hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs["hidden_states"] = hidden_states - - return outputs - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - """Allocate inference cache for the model.""" - if getattr(self.mixer, "allocate_inference_cache", None) is None: - return - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - - -APRIEL_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - Parameters: - config ([`AprielSSMConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Apriel Model outputting raw hidden-states without any specific head on top.", - APRIEL_START_DOCSTRING, -) -class AprielSSMPreTrainedModel(PreTrainedModel): - config_class = AprielSSMConfig - base_model_prefix = "model" - _no_split_modules = ["AprielDecoderLayer"] - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def allocate_inference_cache(self, *args, **kwargs): - """Allocate inference cache for the model.""" - return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) - - -APRIEL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare Apriel Model outputting raw hidden-states without any specific head on top.", - APRIEL_START_DOCSTRING, -) -class AprielSSMModel(AprielSSMPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`] - Args: - config: AprielSSMConfig - """ - - def __init__(self, config: AprielSSMConfig, device=None, dtype=None, **kwargs): - super().__init__(config, device=device, dtype=dtype, **kwargs) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - factory_kwargs = {"device": device, "dtype": dtype} - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) - self.layers = nn.ModuleList( - [AprielDecoderLayer(config, layer_idx, **factory_kwargs) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def allocate_inference_cache(self, *args, **kwargs): - """Allocate inference cache for the model.""" - return {i: layer.allocate_inference_cache(*args, **kwargs) for i, layer in enumerate(self.layers)} - - @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - return_hidden_states=False, - inference_params=None, - position_ids=None, - ) -> Union[tuple, BaseModelOutputWithPast]: - - hidden_states = self.embed_tokens(input_ids) - - # decoder layers - outputs = { - "last_hidden_state": None, - "all_hidden_states": (hidden_states,) if return_hidden_states else (), - } - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - - layer_outputs = decoder_layer( - hidden_states, - inference_params=inference_params, - position_ids=position_ids, - ) - # Record outputs - hidden_states = layer_outputs["hidden_states"] - if return_hidden_states: - outputs["all_hidden_states"] += (hidden_states,) - - outputs["last_hidden_state"] = self.norm(hidden_states) - return outputs - - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - - -class AprielSSMForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - - def __init__(self, config, device=None, dtype=None, **kwargs): - super().__init__(config, device=device, dtype=dtype, **kwargs) - self.model = AprielSSMModel(config, device=device, dtype=dtype) - self.vocab_size = config.vocab_size - factory_kwargs = {"device": device, "dtype": dtype} - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - position_ids=None, - return_hidden_states=False, - return_logits=True, - inference_params=None, - num_last_tokens=0, - **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[tuple, CausalLMOutputWithPast]: - - outputs = self.model( - input_ids, - return_hidden_states=return_hidden_states, - inference_params=inference_params, - position_ids=position_ids, - ) - - if outputs["last_hidden_state"] is not None and return_logits: - logits = self.lm_head(outputs["last_hidden_state"]).float() - outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] - else: - outputs["logits"] = None - - return CustomMambaCausalLMOutput( - loss=None, - logits=outputs["logits"], - all_hidden_states=outputs["all_hidden_states"], - last_hidden_state=outputs["last_hidden_state"], - ) - - def generate(self, *args, **kwargs): - """ - This is a wrapper to make sure we comply with the HF generation interface for eval harness - """ - return super().generate(*args, **kwargs) - - -__all__ = [ - "AprielSSMForCausalLM", - "AprielModel", - "AprielSSMPreTrainedModel", -] diff --git a/fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py deleted file mode 100644 index b8173b733..000000000 --- a/fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py +++ /dev/null @@ -1,94 +0,0 @@ -from enum import Enum - -from transformers.configuration_utils import PretrainedConfig - - -class StateUpdateKernel(Enum): - ssu_verification = "ssu_verification" # selective scan for multi-token verification, not implemented yet - cs = "chunk_scan" # see https://proceedings.mlr.press/v262/wu24a.html - ssu = "standard" # usual one token per time-step inference using selective-scan update, no verification - - -class MTPLlambaConfig(PretrainedConfig): - r"""Configuration class for the CustomMamba model. - - This configuration is used to instantiate the CustomMamba model according to the specified arguments, - defining the model architecture. - - Args: - vocab_size (`int`, *optional*, defaults to 128256): - Vocabulary size of the model. - tie_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - pad_vocab_size_multiple (`int`, *optional*, defaults to 8): - Pad the vocabulary size up to the next multiple of this value. - lm_head_bias (`bool`, *optional*, defaults to `False`): - Whether the LM head includes a bias term. - d_model (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - lm_head_prenorm (`str`, *optional*, defaults to "rms"): - Normalization type for LM head. - n_layer (`int`, *optional*, defaults to 32): - Number of layers in the model. - resid_dropout (`float`, *optional*, defaults to 0.0): - Dropout rate for residual connections. - norm_epsilon (`float`, *optional*, defaults to 1e-5): - Epsilon value used for normalization layers. - mlp_cfg (`dict`, *optional*): - Configuration for the MLP (Multi-Layer Perceptron) layer, including intermediate size, activation function, and whether to use bias. - ssm_cfg (`dict`, *optional*): - Configuration for the SSM (State Space Model) layer, including d_state, number of heads, expansion, and other parameters. - - """ - - model_type = "llamba" - - def __init__( - self, - vocab_size: int, - d_model: int, - tie_embeddings: bool = False, - pad_vocab_size_multiple: int = 8, - lm_head_bias: bool = False, - n_layer: int = 32, - resid_dropout: float = 0.0, - norm_epsilon: float = 1e-5, - mlp_cfg: dict = None, - ssm_cfg: dict = None, - prediction_heads=1, - state_update_kernel: StateUpdateKernel = StateUpdateKernel.cs, - **kwargs, - ): - super().__init__(**kwargs) - - self.vocab_size = vocab_size - self.tie_embeddings = tie_embeddings - self.pad_vocab_size_multiple = pad_vocab_size_multiple - self.lm_head_bias = lm_head_bias - self.d_model = d_model - self.n_layer = n_layer - self.resid_dropout = resid_dropout - self.norm_epsilon = norm_epsilon - self.prediction_heads = prediction_heads - assert ( - state_update_kernel != StateUpdateKernel.ssu_verification - ), "Only chunk scan and standard modes are supported for now" - self.state_update_kernel = state_update_kernel - - # MLP (Multi-Layer Perceptron) Config - self.mlp_cfg = mlp_cfg or { - "intermediate_size": 14336, - "bias": False, - "act_fn": "silu", - } - - # SSM (State Space Model) Config - self.ssm_cfg = ssm_cfg or { - "d_state": 64, - "n_v_heads": 32, - "n_qk_heads": 32, - "expand": 1, - "chunk_size": 128, - "activation": "identity", - "bias": False, - } diff --git a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py deleted file mode 100644 index 6d9746db1..000000000 --- a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py +++ /dev/null @@ -1,389 +0,0 @@ -# Copyright (c) 2024, Kevin Li, Aviv Bick. - -import json -import os -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn as nn -from huggingface_hub import PyTorchModelHubMixin -from mamba_ssm.utils.generation import GenerationMixin -from torch import Tensor, nn -from transformers.activations import ACT2FN -from transformers.utils.generic import ModelOutput - -from .configuration_mtp_llamba import MTPLlambaConfig as LlambaConfig -from .discrete_mamba2 import DiscreteMamba2 - - -class LlamaRMSNorm(nn.Module): - """LlamaRMSNorm (taken from transformers.models.llama.modeling_llama.LlamaRMSNorm).""" - - def __init__(self, hidden_size, eps=1e-6, factory_kwargs=None): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - """ - Args: - hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size). - - Returns: - torch.Tensor of shape (batch_size, seq_len, hidden_size). - """ - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - """Set the extra representation of the module.""" - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class LlamaMLP(nn.Module): - """LlamaMLP (taken from transformers.models.llama.modeling_llama.LlamaMLP).""" - - def __init__(self, hidden_size, intermediate_size, bias, act_fn, factory_kwargs=None): - super().__init__() - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias, **factory_kwargs) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias, **factory_kwargs) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias, **factory_kwargs) - self.act_fn = ACT2FN[act_fn] - - def forward(self, x): - """ - Args: - x: torch.Tensor of shape (batch_size, seq_len, hidden_size). - - Returns: - torch.Tensor of shape (batch_size, seq_len, hidden_size). - """ - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -@dataclass -class CustomMambaCausalLMOutput(ModelOutput): - """Custom output class for MambaLMHeadModel.""" - - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None - last_hidden_state: Optional[torch.FloatTensor] = None - - -class MTPLlambaLMHeadModel(nn.Module, GenerationMixin, PyTorchModelHubMixin): - """MambaLM model with a language modeling head on top (linear layer).""" - - def __init__(self, config, initializer_cfg=None, device=None, dtype=None, **kwargs) -> None: - super().__init__() - - # Load config - if not isinstance(config, LlambaConfig): - config = LlambaConfig(**config) - self.config = config - - # Factory kwargs - factory_kwargs = {"device": device, "dtype": dtype} - - # Pad vocab size to be a multiple of pad_vocab_size_multiple - vocab_size = config.vocab_size - pad_vocab_size_multiple = config.pad_vocab_size_multiple - if vocab_size % pad_vocab_size_multiple != 0: - vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) - self.config.vocab_size = vocab_size - - # Mixer model - self.backbone = MixerModel( - input_size=vocab_size, - config=self.config, - initializer_cfg=initializer_cfg, - **factory_kwargs, - ) - - # MTP heads - self.mtp_heads = nn.ModuleList( - [ - Block( - config=config, - factory_kwargs=factory_kwargs, - layer_idx=layer_idx, - ).to(device) - for layer_idx in range(config.n_layer, config.n_layer + config.prediction_heads - 1) - ] - ) - - self.mtp_norms = nn.ModuleList( - [ - LlamaRMSNorm(config.d_model, eps=config.norm_epsilon, factory_kwargs=factory_kwargs) - for _ in range(config.prediction_heads - 1) - ] - ) - # LM head - if not self.config.tie_embeddings: - self.lm_head = nn.Linear( - in_features=self.config.d_model, - out_features=self.config.vocab_size, - bias=self.config.lm_head_bias, - **factory_kwargs, - ) - else: - self.lm_head = lambda x: x @ self.backbone.embedding.weight.t() - - def allocate_inference_cache(self, *args, **kwargs): - """Allocate inference cache for the model.""" - - mtps = { - i + self.config.n_layer: layer.allocate_inference_cache(*args, **kwargs) - for i, layer in enumerate(self.mtp_heads) - } - return {**self.backbone.allocate_inference_cache(*args, **kwargs), **mtps} - - def forward( - self, - input_ids, - position_ids=None, - return_hidden_states=False, - return_logits=True, - inference_params=None, - num_last_tokens=0, - ): - """ - Args: - input_ids: torch.Tensor of shape (batch_size, seq_len), - position_ids: torch.Tensor of shape (batch_size, seq_len), optional, not used (just for compatibility), - return_hidden_states: bool, optional, - return_logits: bool, optional, whether to compute the logits with the LM head, - inference_params: dict, optional, the model's inference cache, - num_last_tokens: int, optional. If > 0, only return the logits for the last n tokens. - - Returns: - CustomMambaCausalLMOutput. - - """ - outputs = self.backbone( - input_ids, - return_hidden_states=return_hidden_states, - inference_params=inference_params, - position_ids=position_ids, - ) - - # MTP heads processing - latents = [] - hidden_states = outputs["last_hidden_state"] - hidden_states_before_last = outputs["hidden_state_before_last"] - - # last layer already has layer norm applied - latents.append(hidden_states) - - # Process through MTP heads - for i, mtp_head in enumerate(self.mtp_heads): - mtp_outputs = mtp_head( - hidden_states_before_last, - inference_params=inference_params, - position_ids=position_ids, - ) - mtp_hidden_states = mtp_outputs["hidden_states"] - latents.append(self.mtp_norms[i](mtp_hidden_states)) - - # Stack the latents to get (batch_size, seq_len, num_prediction_heads, hidden_size) - stacked_latents = torch.stack(latents, dim=-2) - - if return_logits: - if isinstance(self.lm_head, nn.Linear): - # Apply lm_head to each prediction head's output - logits = self.lm_head(stacked_latents).float() - else: - # Using the tied embedding weights - logits = self.lm_head(stacked_latents) - - outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] - else: - outputs["logits"] = None - - return CustomMambaCausalLMOutput( - loss=None, - logits=outputs["logits"], - all_hidden_states=outputs["all_hidden_states"], - last_hidden_state=stacked_latents, - ) - - def save_pretrained(self, save_directory): - """ - Minimal implementation of save_pretrained for MambaLMHeadModel. - Save the model and its configuration file to a directory. - """ - # Ensure save_directory exists - if not os.path.exists(save_directory): - os.makedirs(save_directory) - - # Save the model's state_dict - model_path = os.path.join(save_directory, "pytorch_model.bin") - torch.save(self.state_dict(), model_path) - - # Save the configuration of the model - config_path = os.path.join(save_directory, "config.json") - with open(config_path, "w") as f: - json.dump(self.config.to_dict(), f) - - -class MixerModel(nn.Module): - """Mixer model with a stack of Mixer layers.""" - - def __init__(self, input_size, config=None, device=None, dtype=None, **kwargs) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.config = config - self.embedding = nn.Embedding(input_size, self.config.d_model, **factory_kwargs) - - self.layers = nn.ModuleList( - [ - Block( - config=config, - factory_kwargs=factory_kwargs, - layer_idx=i, - ).to(device) - for i in range(self.config.n_layer) - ] - ) - - self.final_layernorm = LlamaRMSNorm( - hidden_size=self.config.d_model, - eps=self.config.norm_epsilon, - factory_kwargs=factory_kwargs, - ) - - return - - def allocate_inference_cache(self, *args, **kwargs): - """Allocate inference cache for the model.""" - return {i: layer.allocate_inference_cache(*args, **kwargs) for i, layer in enumerate(self.layers)} - - def forward( - self, - input_ids, - return_hidden_states=False, - inference_params=None, - position_ids=None, - ): - """Run the model.""" - # Start running the layers - hidden_states = self.embedding(input_ids) - - # Initialize outputs - outputs = { - "last_hidden_state": None, - "hidden_state_before_last": None, - "all_hidden_states": (hidden_states,) if return_hidden_states else (), - } - - # Run the layers - for layer in self.layers: - layer_outputs = layer( - hidden_states, - inference_params=inference_params, - position_ids=position_ids, - ) - if layer == self.layers[-1]: - outputs["hidden_state_before_last"] = hidden_states - # Record outputs - hidden_states = layer_outputs["hidden_states"] - if return_hidden_states: - outputs["all_hidden_states"] += (hidden_states,) - - # Last layer, apply layer norm - outputs["last_hidden_state"] = self.final_layernorm(hidden_states) - return outputs - - -class Block(nn.Module): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection. - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - - def __init__(self, config, factory_kwargs, layer_idx, **kwargs): - super().__init__() - self.config = config - self.layer_idx = layer_idx - - # Mixer - self.mixer = DiscreteMamba2( - d_model=self.config.d_model, - layer_idx=layer_idx, - **config.ssm_cfg, - **factory_kwargs, - ) - - # Other components - self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) - self.post_attention_layernorm = LlamaRMSNorm( - hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs - ) - self.mlp = LlamaMLP( - hidden_size=self.config.d_model, - **config.mlp_cfg, - factory_kwargs=factory_kwargs, - ) - - def forward( - self, - hidden_states: Tensor, - inference_params=None, - **kwargs, - ): - """ - Pass the input through the encoder layer. - - Args: - hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), - inference_params: dict, optional, - - Returns: - dict with keys: - hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), - mamba_hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), - transfer_matrix: torch.Tensor of shape (batch_size, seq_len, seq_len). - """ - outputs = {} - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Apply Mixer - mixer_outputs = self.mixer( - hidden_states, - inference_params=inference_params, - ) - - hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs["hidden_states"] = hidden_states - - return outputs - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - """Allocate inference cache for the model.""" - if getattr(self.mixer, "allocate_inference_cache", None) is None: - return - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) diff --git a/fast_llm/models/ssm/huggingface.py b/fast_llm/models/ssm/huggingface.py deleted file mode 100644 index 24005ee9f..000000000 --- a/fast_llm/models/ssm/huggingface.py +++ /dev/null @@ -1,23 +0,0 @@ -import logging -import typing - -from fast_llm.engine.inference.config import HuggingfaceModelConfig -from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM -from fast_llm.models.ssm.config import HybridSSMModelConfig -from fast_llm.models.ssm.model import HybridSSMInferenceRunner, HybridSSMModel - -logger = logging.getLogger(__name__) - - -class HuggingfaceSSMModelConfig(HuggingfaceModelConfig): - model_type = "fast_llm_ssm" - model_config_class = HybridSSMModelConfig - fast_llm_config: HybridSSMModelConfig - - -class HuggingfaceHybridSSMModelForCausalLM(HuggingfaceGPTModelForCausalLM): - config_class = HuggingfaceSSMModelConfig - config: HuggingfaceSSMModelConfig - model_class = HybridSSMModel - runner_class: typing.ClassVar[type[HybridSSMInferenceRunner]] = HybridSSMInferenceRunner - _fast_llm_model: HybridSSMModel diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py deleted file mode 100644 index 0382462b5..000000000 --- a/fast_llm/models/ssm/model.py +++ /dev/null @@ -1,53 +0,0 @@ -import logging -import typing - -from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType - -logger = logging.getLogger(__name__) - - -class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[ConfigType]): - """ - A hybrid model that interleaves Transformer and Mamba blocks. - Right now only LlambaBlock is supported. - As for the mixer, transformer uses MHA. For the LlambaBlock we support Mamba1 and discrete mamba2. - """ - - def _get_block( - self, - block_index: int, - name: str, - return_input: bool = False, - ): - if block_index > self._config.transformer.num_layers: - # MTP block - block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] - else: - # Decoder block - block_type = self._config.hybrid_block_layout[block_index - 1] - - if block_type == SSMBlockType.transformer: - block_config = self._config.transformer - else: - block_config = self._config.transformer.from_dict(self._config.transformer, {"mixer": self._config.ssm}) - - return block_config.get_layer( - self._distributed_config, - hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, - return_input=return_input, - ) - - -class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): - """ - A hybrid model that combines Transformer and SSM blocks. - """ - - base_model_class: typing.ClassVar[type[HybridSSMBaseModel]] = HybridSSMBaseModel - - -class HybridSSMInferenceRunner(GPTInferenceRunner): - model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel diff --git a/fast_llm/models/ssm/trainer.py b/fast_llm/models/ssm/trainer.py deleted file mode 100644 index 39f589384..000000000 --- a/fast_llm/models/ssm/trainer.py +++ /dev/null @@ -1,9 +0,0 @@ -import typing - -from fast_llm.models.gpt.trainer import GPTTrainer -from fast_llm.models.ssm.config import HybridSSMTrainerConfig -from fast_llm.models.ssm.model import HybridSSMModel - - -class HybridSSMTrainer[ConfigType: HybridSSMTrainerConfig](GPTTrainer[ConfigType]): - model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 4323efe3f..b709ea835 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -198,12 +198,12 @@ def global_to_local(self, tensor: torch.Tensor | SafeTensorSlice) -> torch.Tenso assert not self._reductions if tensor.ndim == 0: tensor = tensor[None] - Assert.eq(tensor.shape, self.global_shape) + Assert.eq(tensor.shape, self.global_shape, msg=self) for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim) - Assert.eq(tensor.shape, self.shape) + Assert.eq(tensor.shape, self.shape, msg=self) return tensor @classmethod diff --git a/fast_llm/utils.py b/fast_llm/utils.py index d13ecaf65..1f9feceb4 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -93,8 +93,9 @@ class Assert: @staticmethod def eq(x, *args, msg=None): + assert args for arg in args: - assert x == arg, f"{x} != {arg} " + (f"| {msg}" if msg else "") + assert x == arg, f"{x} != {arg} " + ("" if msg is None else f"| {msg}") @staticmethod def is_(x, y): @@ -457,3 +458,15 @@ def get_and_reset_memory_usage_mib( _global_max_reserved = max(max_reserved, _global_max_reserved) return report + + +def safe_merge_dicts(*dicts) -> dict: + out = {} + for dict_ in dicts: + for key, value in dict_.items(): + if key in out: + if isinstance(value, dict) and isinstance(out[key], dict): + out[key] = safe_merge_dicts(value, out[key]) + Assert.eq(value, out[key]) + out[key] = value + return out diff --git a/fast_llm_external_models/__init__.py b/fast_llm_external_models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py similarity index 89% rename from fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py rename to fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py index 98d2fc28d..12ee343ef 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/configuration_apriel_hybrid_ssm.py @@ -23,11 +23,13 @@ "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, + "dt_proj_bias": True, + "repeat_kv_before_conv": True, } -class AprielSSMHybridConfig(MistralConfig): - model_type = "apriel_ssm_thinker_hybrid" +class AprielHybridSSMConfig(MistralConfig): + model_type = "apriel_hybrid_ssm" def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): super().__init__(**kwargs) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py similarity index 98% rename from fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py rename to fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index 4fde72458..5c0a2216c 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -21,7 +21,7 @@ from transformers.utils import LossKwargs, logging from transformers.utils.generic import ModelOutput -from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig # from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn as varlen_selective_scan_fn # from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as varlen_causal_conv1d_fn @@ -46,7 +46,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class HybridMambaAttentionStaticCache(Cache): - def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): + def __init__(self, config: AprielHybridSSMConfig, batch_size, max_length, dtype=torch.float16, device=None): super().__init__() # config, batch_size, max_length, device, dtype) self.dtype = dtype self.hybrid_override_pattern = config.hybrid_block_layout @@ -231,7 +231,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. """ - def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): + def __init__(self, config: AprielHybridSSMConfig, batch_size, dtype=torch.float16, device=None): super().__init__() self.dtype = dtype self.hybrid_override_pattern = config.hybrid_block_layout @@ -564,8 +564,7 @@ def forward( else: seqlen_offset = kwargs.get("seqlen_offset", cache_position[0]) if cache_position is not None else 0 use_precomputed_states = ( - past_key_value is not None - and past_key_value.has_previous_state + getattr(past_key_value, "has_previous_state", False) and seqlen == 1 and past_key_value.conv_states[self.layer_idx].shape[0] == past_key_value.ssm_states[self.layer_idx].shape[0] @@ -1130,7 +1129,7 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states class AprielSSMDecoderLayer(nn.Module): _mixer_class = DiscreteMamba2 - def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + def __init__(self, config: AprielHybridSSMConfig, layer_idx: int, device=None, dtype=None, **kwargs): super().__init__(**kwargs) factory_kwargs = {"device": device, "dtype": dtype} self.hidden_size = config.hidden_size @@ -1179,7 +1178,7 @@ class AprielSSMM2DecoderLayer(AprielSSMDecoderLayer): class AprielHybridIdentity(nn.Module): - def __init__(self, config: AprielSSMHybridConfig): + def __init__(self, config: AprielHybridSSMConfig): super().__init__() self.config = config @@ -1187,14 +1186,14 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): return (hidden_states,) -class AprielThinkerSSMHybridModel(MistralModel): +class AprielHybridSSMModel(MistralModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] Args: - config: AprielSSMHybridConfig + config: AprielHybridSSMConfig """ - def __init__(self, config: AprielSSMHybridConfig, **kwargs): + def __init__(self, config: AprielHybridSSMConfig, **kwargs): config_copy = copy.deepcopy(config) config_copy.num_hidden_layers = 0 super().__init__(config_copy, **kwargs) @@ -1221,8 +1220,8 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -class AprielThinkerSSMHybridPreTrainedModel(PreTrainedModel): - config_class = AprielSSMHybridConfig +class AprielHybridSSMPreTrainedModel(PreTrainedModel): + config_class = AprielHybridSSMConfig base_model_prefix = "model" _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer", "AprielSSMM2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] @@ -1248,13 +1247,13 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) -class AprielThinkerSSMHybridForCausalLM(AprielThinkerSSMHybridPreTrainedModel, GenerationMixin): +class AprielHybridSSMForCausalLM(AprielHybridSSMPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - def __init__(self, config: AprielSSMHybridConfig, **kwargs): + def __init__(self, config: AprielHybridSSMConfig, **kwargs): super().__init__(config, **kwargs) - self.model = AprielThinkerSSMHybridModel(config) + self.model = AprielHybridSSMModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -1419,7 +1418,7 @@ def forward( __all__ = [ - "AprielThinkerSSMHybridForCausalLM", - "AprielThinkerSSMHybridModel", - "AprielThinkerSSMHybridPreTrainedModel", + "AprielHybridSSMForCausalLM", + "AprielHybridSSMModel", + "AprielHybridSSMPreTrainedModel", ] diff --git a/fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py b/fast_llm_external_models/diffusion_dream/configuration_dream.py similarity index 100% rename from fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py rename to fast_llm_external_models/diffusion_dream/configuration_dream.py diff --git a/fast_llm/models/gpt/external/diffusion_dream/generation_config.json b/fast_llm_external_models/diffusion_dream/generation_config.json similarity index 100% rename from fast_llm/models/gpt/external/diffusion_dream/generation_config.json rename to fast_llm_external_models/diffusion_dream/generation_config.json diff --git a/fast_llm/models/gpt/external/diffusion_dream/generation_utils.py b/fast_llm_external_models/diffusion_dream/generation_utils.py similarity index 100% rename from fast_llm/models/gpt/external/diffusion_dream/generation_utils.py rename to fast_llm_external_models/diffusion_dream/generation_utils.py diff --git a/fast_llm/models/gpt/external/diffusion_dream/modeling_dream.py b/fast_llm_external_models/diffusion_dream/modeling_dream.py similarity index 95% rename from fast_llm/models/gpt/external/diffusion_dream/modeling_dream.py rename to fast_llm_external_models/diffusion_dream/modeling_dream.py index e041d6189..714576eeb 100644 --- a/fast_llm/models/gpt/external/diffusion_dream/modeling_dream.py +++ b/fast_llm_external_models/diffusion_dream/modeling_dream.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -20,37 +19,26 @@ """PyTorch Dream model.""" import math -from typing import List, Optional, Tuple, Union import os +from dataclasses import dataclass +from typing import Optional, Union + import torch import torch.utils.checkpoint from torch import nn -from dataclasses import dataclass - +from transformers import PretrainedConfig from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPast, - MaskedLMOutput, -) -from transformers.utils import ModelOutput +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, -) -from transformers import PretrainedConfig +from transformers.utils import ModelOutput, is_flash_attn_2_available, logging + from .configuration_dream import DreamConfig -from .generation_utils import DreamGenerationMixin, DreamGenerationConfig +from .generation_utils import DreamGenerationConfig, DreamGenerationMixin if is_flash_attn_2_available(): - from transformers.modeling_flash_attention_utils import _flash_attention_forward - from flash_attn import flash_attn_with_kvcache, flash_attn_func + from flash_attn import flash_attn_with_kvcache logger = logging.get_logger(__name__) @@ -131,7 +119,6 @@ def reset_parameters(self): inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq - def _dynamic_frequency_update(self, position_ids, device): """ @@ -287,8 +274,8 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() # Luke: Computing K and Vs for all tokens upto now q_len w/o using cache ? @@ -345,7 +332,7 @@ def forward( attn_weights = None return attn_output, attn_weights, past_key_value - + class DreamSdpaAttention(DreamAttention): """ @@ -364,9 +351,9 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 is_causal: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -384,7 +371,6 @@ def forward( bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) @@ -405,18 +391,16 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings - + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and attention_mask is not None: @@ -435,9 +419,9 @@ def forward( value_states, attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -445,6 +429,7 @@ def forward( return attn_output, None, past_key_value + class DreamFlashAttention(DreamAttention): """ Dream attention module using Flash attention 2. @@ -460,9 +445,9 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 is_causal: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -478,7 +463,7 @@ def forward( ) bsz, q_len, _ = hidden_states.size() - + query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) @@ -489,7 +474,7 @@ def forward( # print(f"hidden_states: {hidden_states.shape} query_states {query_states.shape} key_states {key_states.shape} value_states {value_states.shape}") # print(f"position_ids {position_ids} {position_ids.shape}") - + if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " @@ -500,18 +485,16 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings - + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - + key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - # if query_states.device.type == "cuda" and attention_mask is not None: # query_states = query_states.contiguous() # key_states = key_states.contiguous() @@ -529,19 +512,19 @@ def forward( # attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, # dropout_p=self.attention_dropout if self.training else 0.0, # is_causal=False, # hard coded - # ) - + # ) + # print(f"query_states {query_states.shape} key_states {key_states.shape} value_states {value_states.shape}") - + # replacing with flash attention attn_output = flash_attn_with_kvcache( # q dim (batch_size, seqlen, nheads, headdim) q=query_states.transpose(1, 2).contiguous(), k_cache=key_states.transpose(1, 2).contiguous(), v_cache=value_states.transpose(1, 2).contiguous(), - causal=is_causal, # hard coded + causal=is_causal, # hard coded softmax_scale=1.0 / math.sqrt(self.head_dim), - ) + ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -549,6 +532,7 @@ def forward( return attn_output, None, past_key_value + class DreamDecoderLayer(nn.Module): def __init__(self, config: DreamConfig, layer_idx: int): super().__init__() @@ -559,7 +543,7 @@ def __init__(self, config: DreamConfig, layer_idx: int): f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." ) - + # self.self_attn = Dream_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) if config._attn_implementation == "flash_attention_2": self.self_attn = DreamFlashAttention(config, layer_idx) @@ -575,13 +559,13 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -605,9 +589,9 @@ def forward( """ # print(f"DreamDecoderLayer: past_key_value {past_key_value} use_cache {use_cache}") - + is_casual = kwargs.get("is_casual", False) - + residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -643,9 +627,10 @@ def forward( # When use_cache is True, outputs will have length: # - 2 if output_attentions is False (hidden_states, present_key_value) # - 3 if output_attentions is True (hidden_states, self_attn_weights, present_key_value) - # print(f"DreamDecoderLayer: outputs {len(outputs)}") + # print(f"DreamDecoderLayer: outputs {len(outputs)}") return outputs + class DreamPreTrainedModel(PreTrainedModel): config_class = DreamConfig base_model_prefix = "model" @@ -700,7 +685,7 @@ def from_pretrained( **kwargs, ) # NOTE(Lin): we need to override the generation config - # because the generation config loaded in `from_pretrained` + # because the generation config loaded in `from_pretrained` # does not include all the attributes of DreamGenerationConfig resume_download = kwargs.get("resume_download", None) proxies = kwargs.get("proxies", None) @@ -722,6 +707,7 @@ def from_pretrained( ) return _model + class DreamBaseModel(DreamPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DreamDecoderLayer`] @@ -758,7 +744,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -766,13 +752,13 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, is_casual: Optional[bool] = False, - ) -> Union[Tuple, BaseModelOutput]: + ) -> Union[tuple, BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - + # print("DreamBaseModel: past_key_values", past_key_values, "use_cache", use_cache,) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -789,7 +775,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - + if use_cache and past_key_values is None: past_key_values = DynamicCache() @@ -867,6 +853,8 @@ def forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) + + @dataclass class MaskedLMOutputWithPast(ModelOutput): """ @@ -891,16 +879,17 @@ class MaskedLMOutputWithPast(ModelOutput): past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - past_key_values: Optional[Tuple[Cache]] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + past_key_values: Optional[tuple[Cache]] = None + class DreamModel(DreamGenerationMixin, DreamPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] @@ -942,7 +931,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -952,7 +941,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, **loss_kwargs, - ) -> Union[Tuple, MaskedLMOutput]: + ) -> Union[tuple, MaskedLMOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -993,4 +982,4 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, past_key_values=outputs.past_key_values, - ) \ No newline at end of file + ) diff --git a/fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py b/fast_llm_external_models/diffusion_llama/configuration_diffusion_llama.py similarity index 100% rename from fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py rename to fast_llm_external_models/diffusion_llama/configuration_diffusion_llama.py diff --git a/fast_llm/models/gpt/external/diffusion_llama/generation_utils.py b/fast_llm_external_models/diffusion_llama/generation_utils.py similarity index 100% rename from fast_llm/models/gpt/external/diffusion_llama/generation_utils.py rename to fast_llm_external_models/diffusion_llama/generation_utils.py diff --git a/fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py b/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py similarity index 99% rename from fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py rename to fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py index 5e613093e..c8723af5d 100644 --- a/fast_llm/models/gpt/external/diffusion_llama/modeling_diffusion_llama.py +++ b/fast_llm_external_models/diffusion_llama/modeling_diffusion_llama.py @@ -1,7 +1,3 @@ -import math -import os -from dataclasses import dataclass - # Copyright 2022 ServiceNow. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -20,6 +16,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import math +import os +from dataclasses import dataclass from typing import Callable, Optional, Union import torch @@ -30,20 +30,12 @@ from transformers.cache_utils import Cache, DynamicCache from transformers.integrations import use_kernel_forward_from_hub from transformers.modeling_flash_attention_utils import FlashAttentionKwargs - -# from transformers.modeling_layers import GradientCheckpointingLayer # Update transformer from transformers.modeling_outputs import BaseModelOutputWithPast, MaskedLMOutput from transformers.modeling_rope_utils import dynamic_rope_update from transformers.modeling_utils import PreTrainedModel from transformers.processing_utils import Unpack from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import ( # auto_docstring - LossKwargs, - ModelOutput, - can_return_tuple, - is_torch_flex_attn_available, - logging, -) +from transformers.utils import ModelOutput, can_return_tuple, is_torch_flex_attn_available, logging from .configuration_diffusion_llama import ROPE_INIT_FUNCTIONS, DiffusionLlamaConfig from .generation_utils import SLAMGenerationConfig, SLAMGenerationMixin diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm_external_models/eval/apriel_eval_wrapper.py similarity index 93% rename from fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py rename to fast_llm_external_models/eval/apriel_eval_wrapper.py index ee2c83e03..2405175b6 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm_external_models/eval/apriel_eval_wrapper.py @@ -54,13 +54,13 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig + from fast_llm_external_models.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig self._config = AprielSSMConfig.from_pretrained(pretrained) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.apriel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM + from fast_llm_external_models.apriel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM # Ensure we're using the correct device device = _get_device() @@ -121,13 +121,13 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import AprielSSMHybridConfig + from fast_llm_external_models.apriel_hybrid.configuration_ssm_hybrid_apriel import AprielSSMHybridConfig self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + from fast_llm_external_models.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM # Ensure we're using the correct device device = _get_device() @@ -194,15 +194,13 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import ( - AprielSSMHybridConfig, - ) + from fast_llm_external_models.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + from fast_llm_external_models.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( AprielThinkerSSMHybridForCausalLM, ) diff --git a/fast_llm/models/ssm/external/eval/run_evalchemy.py b/fast_llm_external_models/eval/run_evalchemy.py similarity index 66% rename from fast_llm/models/ssm/external/eval/run_evalchemy.py rename to fast_llm_external_models/eval/run_evalchemy.py index 1cbb5b4da..2758c9ee1 100644 --- a/fast_llm/models/ssm/external/eval/run_evalchemy.py +++ b/fast_llm_external_models/eval/run_evalchemy.py @@ -1,5 +1,6 @@ from eval.eval import cli_evaluate -from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import ( # noqa: F401 + +from fast_llm_external_models.eval.apriel_eval_wrapper import ( # noqa: F401 AprielHybrid15bSSMWrapper, AprielHybridSSMWrapper, AprielSSMWrapper, diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm_external_models/eval/run_lm_eval.py similarity index 67% rename from fast_llm/models/ssm/external/eval/run_lm_eval.py rename to fast_llm_external_models/eval/run_lm_eval.py index 53c0febab..8d37584c4 100644 --- a/fast_llm/models/ssm/external/eval/run_lm_eval.py +++ b/fast_llm_external_models/eval/run_lm_eval.py @@ -1,6 +1,6 @@ from lm_eval.__main__ import cli_evaluate -from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import ( # noqa: F401 +from fast_llm_external_models.eval.apriel_eval_wrapper import ( # noqa: F401 AprielHybrid15bSSMWrapper, AprielHybridSSMWrapper, AprielSSMWrapper, diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py b/fast_llm_external_models/make_hybrid_checkpoint_with_importance_15b_mil.py similarity index 96% rename from fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py rename to fast_llm_external_models/make_hybrid_checkpoint_with_importance_15b_mil.py index dde11cfbc..f5d09da61 100644 --- a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py +++ b/fast_llm_external_models/make_hybrid_checkpoint_with_importance_15b_mil.py @@ -3,8 +3,8 @@ import transformers from transformers import AutoConfig, AutoModelForCausalLM -from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig -from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( +from fast_llm_external_models.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm_external_models.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( AprielSSMM2DecoderLayer, AprielThinkerSSMHybridForCausalLM, ) diff --git a/fast_llm/models/gpt/external/mtp_llama/configuration_mtp_llama.py b/fast_llm_external_models/mtp_llama/configuration_mtp_llama.py similarity index 100% rename from fast_llm/models/gpt/external/mtp_llama/configuration_mtp_llama.py rename to fast_llm_external_models/mtp_llama/configuration_mtp_llama.py diff --git a/fast_llm/models/gpt/external/mtp_llama/modeling_mtp_llama.py b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py similarity index 100% rename from fast_llm/models/gpt/external/mtp_llama/modeling_mtp_llama.py rename to fast_llm_external_models/mtp_llama/modeling_mtp_llama.py diff --git a/setup.cfg b/setup.cfg index 843aa15ca..77073ab55 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,9 @@ name = fast_llm [options] -packages = find_namespace: +packages = + fast_llm + fast_llm_external_models include_package_data = True python_requires = >=3.12 install_requires = diff --git a/tests/conftest.py b/tests/conftest.py index 86937326c..58301919f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,6 +34,10 @@ from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip from tests.utils.utils import result_path, format_resource_report, report_subtest # isort: skip +# Import all dynamic classes. +import fast_llm.cli # isort: skip + + logger = logging.getLogger(__name__) manager: DependencyManager | None = None diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index e402659b0..d52564cc0 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -162,11 +162,13 @@ def test_lm_head( ): config = GPTBaseModelConfig.from_dict( { - "transformer": { + "decoder": { + "num_blocks": 0, + }, + "embeddings_layer": { + "vocab_size": VOCAB_SIZE, "hidden_size": HIDDEN_SIZE, - "num_layers": 0, }, - "embeddings_layer": {"vocab_size": VOCAB_SIZE}, "output_layer": { "cross_entropy_implementation": cross_entropy_impl, "normalization": {"type": "rms_norm"}, @@ -239,7 +241,7 @@ def test_lm_head( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.training_dtype.torch, device=distributed.device ) - .normal_(config.transformer.hidden_size**-0.5) + .normal_(config.embeddings_layer.hidden_size**-0.5) .requires_grad_(True) ) kwargs[WORD_EMBEDDINGS_WEIGHT if config.output_layer.tied_weight else OUTPUT_WEIGHTS] = logit_weight diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py index 51687c6d8..217ecd0e1 100644 --- a/tests/models/distributed_test_checkpoint.py +++ b/tests/models/distributed_test_checkpoint.py @@ -63,6 +63,8 @@ def main(args: list[str] | None = None) -> None: group = pool.get_process_group(range(world_size), rank) for config in DISTRIBUTED_SAVE_LOAD_CONFIGS.values(): + if config.load_format == "{checkpoint_format}" and model_testing_config.checkpoint_format is None: + continue config = config.resolve(base_path, model_testing_config) Assert.eq(world_size, config.num_gpus) with DistributedSubtestContext(base_path, config.name, group, world_size, enabled=do_capture) as subtest: diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index ed911fc8a..714abc130 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -119,31 +119,38 @@ def test_conversion(model_testing_config, run_conversion, get_convert_path): DistributedCheckpointFormat, FastLLMCheckpointFormat, ) - run_conversion( - get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), - FastLLMCheckpointFormat, - model_testing_config.checkpoint_format, - ) - run_conversion( - get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat), - model_testing_config.checkpoint_format, - DistributedCheckpointFormat, - ) - run_conversion( - get_convert_path(), - DistributedCheckpointFormat, - model_testing_config.checkpoint_format, - ) - run_conversion( - get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat), - model_testing_config.checkpoint_format, - FastLLMCheckpointFormat, - ) - run_conversion( - get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format), - FastLLMCheckpointFormat, - DistributedCheckpointFormat, - ) + if model_testing_config.checkpoint_format is None: + run_conversion( + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), + FastLLMCheckpointFormat, + DistributedCheckpointFormat, + ) + else: + run_conversion( + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), + FastLLMCheckpointFormat, + model_testing_config.checkpoint_format, + ) + run_conversion( + get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat), + model_testing_config.checkpoint_format, + DistributedCheckpointFormat, + ) + run_conversion( + get_convert_path(), + DistributedCheckpointFormat, + model_testing_config.checkpoint_format, + ) + run_conversion( + get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat), + model_testing_config.checkpoint_format, + FastLLMCheckpointFormat, + ) + run_conversion( + get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format), + FastLLMCheckpointFormat, + DistributedCheckpointFormat, + ) def _compare_safetensor_files( @@ -170,20 +177,29 @@ def _compare_safetensor_files( @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_converted_round_trip(model_testing_config, get_convert_path): # Test that the various possible conversion paths yield identical results. - _compare_safetensor_files( - get_convert_path() / "rank_0.safetensors", - get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors", - get_convert_path(DistributedCheckpointFormat, model_testing_config.checkpoint_format) / "rank_0.safetensors", - expected_keys={_WEIGHT_SHARD_SAVE_NAME}, - ) - _compare_safetensor_files( - get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / "model_0.safetensors", - get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) / "model_0.safetensors", - ) - _compare_safetensor_files( - get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) / "model_0.safetensors", - get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat) / "model_0.safetensors", - ) + if model_testing_config.checkpoint_format is None: + _compare_safetensor_files( + get_convert_path() / "rank_0.safetensors", + get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors", + expected_keys={_WEIGHT_SHARD_SAVE_NAME}, + ) + else: + _compare_safetensor_files( + get_convert_path() / "rank_0.safetensors", + get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors", + get_convert_path(DistributedCheckpointFormat, model_testing_config.checkpoint_format) + / "rank_0.safetensors", + expected_keys={_WEIGHT_SHARD_SAVE_NAME}, + ) + _compare_safetensor_files( + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / "model_0.safetensors", + get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) / "model_0.safetensors", + ) + _compare_safetensor_files( + get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) + / "model_0.safetensors", + get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat) / "model_0.safetensors", + ) def _compare_model_configs(config_ref: FastLLMModelConfig, config_test: FastLLMModelConfig): @@ -223,6 +239,24 @@ def test_load_pretrained( reference_config = model_testing_config.model_config_class.from_dict( yaml.safe_load(get_convert_path().parents[1].joinpath("config.yaml").open("r"))["model"] ) + reference_shard = safetensors.torch.load_file(get_convert_path() / "rank_0.safetensors", device="cuda")[ + _WEIGHT_SHARD_SAVE_NAME + ] + load_and_compare_checkpoints( + FastLLMCheckpointFormat, + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), + reference_config, + reference_shard, + ) + if model_testing_config.checkpoint_format is None: + load_and_compare_checkpoints( + DistributedCheckpointFormat, + get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat), + reference_config, + reference_shard, + ) + return + reference_config_from_hf = model_testing_config.model_config_class.from_dict( { "base_model": yaml.safe_load( @@ -234,10 +268,6 @@ def test_load_pretrained( ) _compare_architectures(reference_config, reference_config_from_hf) - reference_shard = safetensors.torch.load_file(get_convert_path() / "rank_0.safetensors", device="cuda")[ - _WEIGHT_SHARD_SAVE_NAME - ] - load_and_compare_checkpoints(DistributedCheckpointFormat, get_convert_path(), reference_config, reference_shard) load_and_compare_checkpoints( @@ -253,12 +283,6 @@ def test_load_pretrained( reference_shard, ) - load_and_compare_checkpoints( - FastLLMCheckpointFormat, - get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), - reference_config, - reference_shard, - ) load_and_compare_checkpoints( FastLLMCheckpointFormat, get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format), @@ -284,6 +308,8 @@ def test_load_pretrained( @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_huggingface_model(model_testing_config, get_convert_path): + if model_testing_config.checkpoint_format is None: + return # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. # TODO: Stress the importance of this test as the main correctness test for most models. # TODO: Review test. Move to test_generate? @@ -354,7 +380,8 @@ def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_ import tests.models.distributed_test_checkpoint script = [ - tests.models.distributed_test_checkpoint.__file__, + "-m", + tests.models.distributed_test_checkpoint.__name__, str(run_test_script_base_path), model_testing_config.name, ] @@ -388,6 +415,11 @@ def test_load_parallel_checkpoint_in_single_gpu( reference_distributed_shard, report_subtest, ): + if ( + model_testing_config.checkpoint_format is None + and distributed_save_load_config.load_format == "{checkpoint_format}" + ): + return # This should only happen when test is skipped (failed dependency). assert reference_distributed_shard is not None distributed_save_load_config = distributed_save_load_config.resolve( @@ -416,11 +448,8 @@ def test_parallel_checkpoint_consistency(model_testing_config, run_test_script_b .resolve(base_path=run_test_script_base_path, model_testing_config=model_testing_config) .save_path / f"{DistributedCheckpointFormat.name}/rank_{rank}.safetensors" - for format_ in ( - DistributedCheckpointFormat.name, - FastLLMCheckpointFormat.name, - "{checkpoint_format}", - ) + for format_ in (DistributedCheckpointFormat.name, FastLLMCheckpointFormat.name) + + (() if model_testing_config.checkpoint_format is None else ("{checkpoint_format}",)) ] ) diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py index 7f0b902f8..ad0de47e6 100644 --- a/tests/models/test_generate.py +++ b/tests/models/test_generate.py @@ -7,7 +7,8 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.schedule.config import ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import PretrainedGPTModelConfig +from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -55,14 +56,14 @@ def _get_hf_model(model_path: str, use_flash_attention: bool, use_bf16: bool): def _get_fast_llm_model( - model_path: str, use_flash_attention: bool, use_bf16: bool, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat + model_path: str, use_flash_attention: bool, use_bf16: bool, checkpoint_format=LlamaCheckpointFormat ): updates = {} if use_flash_attention: - updates[("base_model", "transformer", "use_flash_attention")] = True + updates[("base_model", "decoder", "block", "mixer", "use_flash_attention")] = True updates[("distributed", "training_dtype")] = "bf16" else: - updates[("base_model", "transformer", "use_flash_attention")] = False + updates[("base_model", "decoder", "block", "mixer", "use_flash_attention")] = False if use_bf16: updates[("distributed", "training_dtype")] = "bf16" return HuggingfaceGPTModelForCausalLM.from_pretrained( @@ -76,7 +77,7 @@ def _get_fast_llm_model( def _get_fast_llm_model_from_model( - model_path: str, use_flash_attention: bool, use_bf16: bool, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat + model_path: str, use_flash_attention: bool, use_bf16: bool, checkpoint_format=LlamaCheckpointFormat ): updates = { ("pretrained", "path"): model_path, @@ -85,10 +86,10 @@ def _get_fast_llm_model_from_model( } if use_flash_attention: - updates[("model", "base_model", "transformer", "use_flash_attention")] = True + updates[("model", "base_model", "decoder", "block", "mixer", "use_flash_attention")] = True updates[("model", "distributed", "training_dtype")] = "bf16" else: - updates[("model", "base_model", "transformer", "use_flash_attention")] = False + updates[("model", "base_model", "decoder", "block", "mixer", "use_flash_attention")] = False if use_bf16: updates[("model", "distributed", "training_dtype")] = "bf16" @@ -227,7 +228,7 @@ def test_generate( ): _test_generate( model_path, - LlamaGPTHuggingfaceCheckpointFormat, + LlamaCheckpointFormat, use_flash_attention, use_bf16, max_new_tokens, @@ -311,9 +312,7 @@ def _test_generate_from_model(model_path, tokenizer, fast_llm_checkpoint_format) def test_generate_from_model( model_path, ): - _test_generate_from_model( - model_path, AutoTokenizer.from_pretrained(model_path), LlamaGPTHuggingfaceCheckpointFormat - ) + _test_generate_from_model(model_path, AutoTokenizer.from_pretrained(model_path), LlamaCheckpointFormat) @requires_cuda @@ -353,16 +352,14 @@ def _test_forward_return_hidden_states( ) # hidden_states include embeddings layer - assert ( - len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_layers - ) + assert len(res_fast_llm.hidden_states) - 1 == len(fast_llm_model.config.fast_llm_config.base_model.decoder) @pytest.mark.extra_slow @requires_cuda def test_forward_return_hidden_states(model_path): _test_forward_return_hidden_states( - model_path, LlamaGPTHuggingfaceCheckpointFormat, AutoTokenizer.from_pretrained(model_path).vocab_size + model_path, LlamaCheckpointFormat, AutoTokenizer.from_pretrained(model_path).vocab_size ) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 5c4897646..d14721142 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -57,7 +57,12 @@ def test_and_compare_model( def test_run_model_distributed(run_distributed_script, model_testing_config, run_test_script_base_path, request): import tests.models.distributed_test_model - script = [tests.models.distributed_test_model.__file__, str(run_test_script_base_path), model_testing_config.name] + script = [ + "-m", + tests.models.distributed_test_model.__name__, + str(run_test_script_base_path), + model_testing_config.name, + ] if request.config.getoption("distributed_capture"): logger.warning( "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." diff --git a/tests/test_config.py b/tests/test_config.py index 03d535520..4e73569b3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -74,20 +74,24 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config = GPTModelConfig.from_dict( { "base_model": { - "transformer": { - "mixer": { - "rotary": {"type": "default"}, - "window_size": 32, - "head_groups": 4, - }, - "mlp": { - "intermediate_size": 4096, # Implicit default, default value - "activation": "silu", # Implicit default, non-default value - }, - "normalization": {"type": "rms_norm"}, # Nested - "num_layers": 12, # Default + "embeddings_layer": { "hidden_size": 1024, # Default }, + "decoder": { + "block": { + "mixer": { + "rotary": {"type": "default"}, + "window_size": 32, + "head_groups": 4, + }, + "mlp": { + "intermediate_size": 4096, # Implicit default, default value + "activation": "silu", # Implicit default, non-default value + }, + "normalization": {"type": "rms_norm"}, # Nested + }, + "num_blocks": 12, # Default + }, "output_layer": {"tied_weight": False}, }, "multi_stage": {"zero_stage": 3}, @@ -101,15 +105,16 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config.save_metadata(save_config) base_model_update = { - "transformer": { - "mixer": { - "head_groups": 1, # Override to default + "embeddings_layer": {"hidden_size": 512, "vocab_size": 1000}, + "decoder": { + "block": { + "mixer": { + "head_groups": 1, # Override to default + }, + # rotary: Don't override nested. + "normalization": {"implementation": "triton"}, # Update non-default nested }, - # rotary: Don't override nested. - "normalization": {"implementation": "triton"}, # Update non-default nested - "hidden_size": 512, # Override, affects derived value (kv channels) }, - "embeddings_layer": {"vocab_size": 1000}, "peft": {"type": "lora", "freeze_others": False}, # Update default nested, change type } pretrained_config = PretrainedGPTModelConfig.from_dict( @@ -129,36 +134,43 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): expected_config["distributed"].update({"seed": 1234, "training_dtype": "float16"}) if load_config in (ModelConfigType.fast_llm, ModelConfigType.model): expected_config["base_model"] = { - "transformer": { - "mixer": { - "type": "attention", - "rotary": {"type": "default"}, - "window_size": 32, - "head_groups": 1, - }, - "mlp": { - "type": "mlp", - "intermediate_size": 4096, # Implicit default, default value - "activation": "silu", # Implicit default, non-default value - }, - "normalization": {"type": "rms_norm", "implementation": "triton"}, - "num_layers": 12, + "embeddings_layer": { "hidden_size": 512, + "vocab_size": 1000, + }, + "decoder": { + "type": "fixed", + "block": { + "type": "decoder", + "mixer": { + "type": "attention", + "rotary": {"type": "default"}, + "window_size": 32, + "head_groups": 1, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 4096, # Implicit default, default value + "activation": "silu", # Implicit default, non-default value + }, + "normalization": {"type": "rms_norm", "implementation": "triton"}, + }, + "num_blocks": 12, }, - "embeddings_layer": {"vocab_size": 1000}, "output_layer": {"tied_weight": False, "normalization": {"type": "layer_norm"}}, "peft": {"type": "lora", "freeze_others": False}, } else: - base_model_update["transformer"]["normalization"]["type"] = "layer_norm" - base_model_update["transformer"]["mixer"]["type"] = "attention" - base_model_update["transformer"]["mixer"]["rotary"] = {"type": "none"} - base_model_update["transformer"]["mlp"] = {"type": "mlp"} + base_model_update["decoder"]["type"] = "fixed" + base_model_update["decoder"]["block"]["type"] = "decoder" + base_model_update["decoder"]["block"]["normalization"]["type"] = "layer_norm" + base_model_update["decoder"]["block"]["mixer"]["type"] = "attention" + base_model_update["decoder"]["block"]["mixer"]["rotary"] = {"type": "none"} + base_model_update["decoder"]["block"]["mlp"] = {"type": "mlp"} base_model_update["output_layer"] = {"normalization": {"type": "layer_norm"}} base_model_update["peft"] = {"type": "lora", "freeze_others": False} expected_config["base_model"] = base_model_update - print("IKEUFGH", serialized_config, expected_config) check_equal_nested(serialized_config, expected_config) @@ -276,6 +288,6 @@ def test_distributed_global_ranks(bdp: int, sdp: int, tp: int, pp: int, pipeline # Check that the global ranks are partitioned into disjoint groups for each distributed dimension, # and indirectly that `DistributedDim.global_ranks` is consistent between ranks. Assert.eq(sum(len(global_ranks) for global_ranks in global_ranks_set), world_size) - Assert.eq(len({global_rank for global_ranks in global_ranks_set for global_rank in global_ranks})) + Assert.eq(len({global_rank for global_ranks in global_ranks_set for global_rank in global_ranks}), world_size) Assert.eq(len(rank_breakdowns), world_size) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 1b49dcfcc..cc5a60a8a 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -1,59 +1,79 @@ +import copy + import pytest from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.training.config import TrainerConfig -from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.block.block import Block +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.decoder.block import DecoderBlock from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda -def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: - cls = TrainerConfig.get_subclass(model_type) - parsed, unparsed = cls._get_parser().parse_known_args(args) - config: TrainerConfig = cls._from_parsed_args(parsed, unparsed) - distributed = Distributed(config.model.distributed) - trainer = config.get_trainer_class()(config=config) - trainer.setup(distributed, config.get_run(distributed)) - return trainer +def _get_model(config_dict: dict, model_type: str = "gpt") -> FastLLMModel: + cls = FastLLMModelConfig.get_subclass(model_type) + config: FastLLMModelConfig = cls.from_dict(config_dict) + model = config.get_model_class()(config) + model.setup(Distributed(config.distributed)) + return model @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): get_model_test_dataset() - args = model_testing_config.config_args + ["run.tensor_logs.save=False"] - model_ref = _get_trainer_from_args(args, model_testing_config.model_type)._multi_stage - model_frozen = _get_trainer_from_args( - args + [f"model.base_model.transformer.mlp.lr_scale=0"], - model_testing_config.model_type, - )._multi_stage + frozen_config_dict = copy.deepcopy(model_testing_config.config_dict) + decoder_config = frozen_config_dict["model"]["base_model"]["decoder"] + if (decoder_type := decoder_config.get("type", "fixed")) == "fixed": + decoder_config["block"]["mlp"]["lr_scale"] = 0 + elif decoder_type == "pattern": + for block_config in decoder_config["blocks"].values(): + block_config["mlp"]["lr_scale"] = 0 + else: + raise NotImplementedError(decoder_type) + + model_ref = _get_model(model_testing_config.config_dict["model"], model_testing_config.model_type) + model_frozen = _get_model(frozen_config_dict["model"], model_testing_config.model_type) Assert.eq( model_ref._num_stages, model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, Block) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, DecoderBlock) else 0 for layer in model_ref.base_model.layers ] - for weight_buffer_ref, weight_buffer_frozen in zip( - model_ref._weight_buffers, model_frozen._weight_buffers, strict=True - ): - Assert.eq(weight_buffer_ref.numel() == weight_buffer_frozen.numel()) - for grad_buffer_ref, grad_buffer_frozen, frozen_parameter_count in zip( - model_ref._grad_buffers, model_frozen._grad_buffers, frozen_parameter_counts, strict=True - ): - Assert.eq(grad_buffer_ref.numel() - grad_buffer_frozen.numel() == frozen_parameter_count) + # Make sure each layer has its own buffer so the check below works. + Assert.eq( + num_stages := len(model_ref.base_model.layers), + len(model_frozen.base_model.layers), + len(model_ref.stages), + len(model_frozen.stages), + ) + for stage_index in range(num_stages): + # Weight buffers are the same. + Assert.eq( + model_ref._weight_buffers[model_ref._weight_buffer_indices[stage_index]].numel(), + model_frozen._weight_buffers[model_frozen._weight_buffer_indices[stage_index]].numel(), + ) + # Weight buffers exclude frozen weights. + Assert.eq( + model_ref._grad_buffers[model_ref._grad_buffer_indices[stage_index]].numel() + - model_frozen._grad_buffers[model_frozen._grad_buffer_indices[stage_index]].numel(), + frozen_parameter_counts[stage_index], + ) for shard_name, shard_frozen_count in zip( model_ref._shard_names, [0] + [sum(frozen_parameter_counts)] * (len(model_ref._all_shard_names) - 1), strict=True, ): + # Same with shards. Assert.eq( - model_ref.get_shard(shard_name).numel() - model_frozen.get_shard(shard_name).numel(), shard_frozen_count + model_ref.get_shard(shard_name).numel() - model_frozen.get_shard(shard_name).numel(), + shard_frozen_count, + msg=shard_name, ) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 5c3ecd8a2..306beadf8 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -67,7 +67,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon sub_configs={ ("init", None): get_config(), # Saved gradient include the gradient scaling by 2**16 (default initial value) - (None, "fw"): get_config(1e-3, 3e-4), + (None, "fw"): get_config(1.2e-3, 3e-4), (None, "bw"): get_config(3e-3, 1e-5, scale=2**16), (None, "bias"): get_config(3e-3, 1e-4, scale=2**16), (None, "gradient"): get_config(3e-3, 5e-5, scale=2**16), diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index abd3d4bad..55ac4ae74 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -11,20 +11,15 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.models.gpt.config import ( - DiffusionDreamGPTHuggingfaceCheckpointFormat, - DiffusionLlamaGPTHuggingfaceCheckpointFormat, - LlamaGPTHuggingfaceCheckpointFormat, - MistralGPTHuggingfaceCheckpointFormat, - MixtralGPTHuggingfaceCheckpointFormat, - MTPLlamaGPTHuggingfaceCheckpointFormat, - Qwen2GPTHuggingfaceCheckpointFormat, - Starcoder2GPTHuggingfaceCheckpointFormat, -) -from fast_llm.models.ssm.config import ( - AprielSSMHHybridHuggingfaceCheckpointFormat, - AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, - LLambaHuggingfaceCheckpointFormat, +from fast_llm.models.gpt.conversion.config import ( + AprielHybridSSMCheckpointFormat, + DiffusionDreamCheckpointFormat, + DiffusionLlamaCheckpointFormat, + LlamaCheckpointFormat, + MistralCheckpointFormat, + MixtralCheckpointFormat, + MTPLlamaCheckpointFormat, + Qwen2CheckpointFormat, ) from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE @@ -175,9 +170,9 @@ def _update_and_add_testing_config( # Needed to match Megatron (init_1 / (2 * num_layers) ** 0.5) init_2 = {"initialization": {"type": "normal", "std": 2**-6.5}} -MODEL_CONFIGS["gpt2"] = ModelTestingConfig( +MODEL_CONFIGS["gpt_2"] = ModelTestingConfig( # Tests gpt2 features (absolute embeddings, layer norm, relu activation, tied embeddings, MHA, linear biases). - name="gpt2", + name="gpt_2", model_type="gpt", config_dict={ "run": { @@ -197,22 +192,28 @@ def _update_and_add_testing_config( "embeddings_layer": { "word_embeddings": init_1, "position_embeddings": {"enabled": True, **init_1}, + "hidden_size": 256, "num_position_embeddings": 512, "vocab_size": MODEL_TEST_VOCAB_SIZE, }, - "transformer": { - "mixer": { - "query_layer": {"weight": init_1}, - "key_layer": {"weight": init_1}, - "value_layer": {"weight": init_1}, - "dense_layer": {"weight": init_2}, - "heads": 8, - "head_groups": 8, - "head_size": 32, + "decoder": { + "block": { + "mixer": { + "query_layer": {"weight": init_1}, + "key_layer": {"weight": init_1}, + "value_layer": {"weight": init_1}, + "dense_layer": {"weight": init_2}, + "heads": 8, + "head_groups": 8, + "head_size": 32, + }, + "mlp": { + "layer_1": {"weight": init_1}, + "layer_2": {"weight": init_2}, + "intermediate_size": 1024, + }, }, - "mlp": {"layer_1": {"weight": init_1}, "layer_2": {"weight": init_2}, "intermediate_size": 1024}, - "num_layers": 2, - "hidden_size": 256, + "num_blocks": 2, }, "output_layer": {"output_weight": init_1}, }, @@ -288,7 +289,8 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.main, ModelTestingGroup.checkpoint: ModelTestingGroupAction.main, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + # TODO: PP checkpoint failing for tied weights. + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.normal, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, @@ -297,10 +299,10 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests MQA. - "gpt2", + "gpt_2", "starcoder", updates={ - ("model", "base_model", "transformer", "mixer", "head_groups"): 1, + ("model", "base_model", "decoder", "block", "mixer", "head_groups"): 1, }, megatron_args=["--group-query-attention"], checkpoint_format=None, @@ -316,11 +318,11 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests intermediate between gpt2 and llama, closest converter to gpt2. - "gpt2", - "starcoder2", + "gpt_2", + "starcoder_2", updates={ - ("model", "base_model", "transformer", "mixer", "head_groups"): 4, - ("model", "base_model", "transformer", "mixer", "rotary", "type"): "default", + ("model", "base_model", "decoder", "block", "mixer", "head_groups"): 4, + ("model", "base_model", "decoder", "block", "mixer", "rotary", "type"): "default", ("model", "base_model", "embeddings_layer", "position_embeddings", "enabled"): False, }, megatron_args=[ @@ -329,7 +331,7 @@ def _update_and_add_testing_config( "--use-rotary-position-embeddings", "--no-position-embedding", ], - checkpoint_format=Starcoder2GPTHuggingfaceCheckpointFormat, + checkpoint_format=None, # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, @@ -343,14 +345,14 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Main tested model. - "starcoder2", + "starcoder_2", "llama", updates={ - ("model", "base_model", "transformer", "mixer", "add_linear_biases"): False, - ("model", "base_model", "transformer", "mlp", "gated"): True, - ("model", "base_model", "transformer", "mlp", "activation"): "silu", - ("model", "base_model", "transformer", "mlp", "add_linear_biases"): False, - ("model", "base_model", "transformer", "normalization", "type"): "rms_norm", + ("model", "base_model", "decoder", "block", "mixer", "add_linear_biases"): False, + ("model", "base_model", "decoder", "block", "mlp", "gated"): True, + ("model", "base_model", "decoder", "block", "mlp", "activation"): "silu", + ("model", "base_model", "decoder", "block", "mlp", "add_linear_biases"): False, + ("model", "base_model", "decoder", "block", "normalization", "type"): "rms_norm", ("model", "base_model", "output_layer", "normalization", "type"): "rms_norm", ("model", "base_model", "output_layer", "tied_weight"): False, }, @@ -361,7 +363,7 @@ def _update_and_add_testing_config( "--ffn-hidden-size=1024", "--untie-embeddings-and-output-weights", ], - checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, + checkpoint_format=LlamaCheckpointFormat, # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.main, @@ -376,13 +378,13 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests llama3-style rotary embeddings. "llama", - "llama3", + "llama_3", updates={ - ("model", "base_model", "transformer", "mixer", "rotary", "type"): "llama3", + ("model", "base_model", "decoder", "block", "mixer", "rotary", "type"): "llama3", }, # Megatron doesn't support Llama3-style Rotary Embeddings megatron_args=None, - checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, + checkpoint_format=LlamaCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, @@ -398,11 +400,11 @@ def _update_and_add_testing_config( "llama", "llama_yarn", updates={ - ("model", "base_model", "transformer", "mixer", "rotary", "type"): "yarn", + ("model", "base_model", "decoder", "block", "mixer", "rotary", "type"): "yarn", }, # Megatron doesn't support Yarn-style Rotary Embeddings megatron_args=None, - checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, + checkpoint_format=LlamaCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, @@ -420,7 +422,7 @@ def _update_and_add_testing_config( updates={}, # Megatron doesn't support Yarn-style Rotary Embeddings megatron_args=None, - checkpoint_format=DiffusionLlamaGPTHuggingfaceCheckpointFormat, + checkpoint_format=DiffusionLlamaCheckpointFormat, # TODO: Conversion is broken. # TODO: Add back generate as `normal` when stable. groups={ @@ -436,13 +438,13 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests multi-token prediction, custom HF model and converter. "llama", - "llama_mtp", + "mtp_llama", updates={ - ("model", "base_model", "output_layer", "prediction_heads"): 4, + ("model", "base_model", "output_layer", "prediction_heads"): 2, }, # Megatron doesn't support multi-token prediction. megatron_args=None, - checkpoint_format=MTPLlamaGPTHuggingfaceCheckpointFormat, + checkpoint_format=MTPLlamaCheckpointFormat, # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, @@ -458,14 +460,14 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests partial linear biases, Qwen2 converter. "llama", - "qwen2", + "qwen_2", # TODO: replace updates={ - ("model", "base_model", "transformer", "add_linear_biases"): "only_attn_qkv", + ("model", "base_model", "decoder", "block", "add_linear_biases"): "only_attn_qkv", }, # Megatron doesn't support per sub layer biases. megatron_args=None, - checkpoint_format=Qwen2GPTHuggingfaceCheckpointFormat, + checkpoint_format=Qwen2CheckpointFormat, # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.broken, @@ -479,13 +481,13 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests diffusion dream converter. - "qwen2", + "qwen_2", "dream", # TODO: replace only_attn_qkv updates={}, # Megatron doesn't support per sub layer biases. megatron_args=None, - checkpoint_format=DiffusionDreamGPTHuggingfaceCheckpointFormat, + checkpoint_format=DiffusionDreamCheckpointFormat, # TODO: Conversion is broken. # TODO: Add back generate as `normal` when stable. groups={ @@ -503,11 +505,11 @@ def _update_and_add_testing_config( "llama", "mistral", updates={ - ("model", "base_model", "transformer", "mixer", "window_size"): 128, + ("model", "base_model", "decoder", "block", "mixer", "window_size"): 128, }, # Megatron doesn't support sliding windows. megatron_args=None, - checkpoint_format=MistralGPTHuggingfaceCheckpointFormat, + checkpoint_format=MistralCheckpointFormat, # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, @@ -524,16 +526,16 @@ def _update_and_add_testing_config( "llama", "mixtral", updates={ - ("model", "base_model", "transformer", "mlp", "type"): "moe", - ("model", "base_model", "transformer", "mlp", "router", "weight"): init_1, - ("model", "base_model", "transformer", "mlp", "experts"): 4, - ("model", "base_model", "transformer", "mlp", "experts_per_token"): 4, + ("model", "base_model", "decoder", "block", "mlp", "type"): "moe", + ("model", "base_model", "decoder", "block", "mlp", "router", "weight"): init_1, + ("model", "base_model", "decoder", "block", "mlp", "experts"): 4, + ("model", "base_model", "decoder", "block", "mlp", "experts_per_token"): 4, }, megatron_args=[ "--num-experts=4", "--moe-router-topk=4", ], - checkpoint_format=MixtralGPTHuggingfaceCheckpointFormat, + checkpoint_format=MixtralCheckpointFormat, # TODO: New base image broke mixtral groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, @@ -546,29 +548,41 @@ def _update_and_add_testing_config( compare_factor=2.0, ) +_llama_block = MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]["block"] + + _update_and_add_testing_config( # Tests hybrid Mamba, llamba converter. "llama", - "llamba", - model_type="hybrid_ssm", + "hybrid_mamba", updates={ - ("model", "base_model", "ssm"): { - "type": "mamba", - "d_inner": 512, - "state_size": 16, - "dt_rank": 16, - "add_linear_biases": False, + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "t": copy.deepcopy(_llama_block), + "m": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "mamba", + "d_inner": 512, + "state_size": 16, + "dt_rank": 16, + "add_linear_biases": False, + }, + }, + }, + "num_blocks": 2, + "pattern": ["t", "m"], }, - ("model", "base_model", "hybrid_block_layout"): "['t','m']", }, megatron_args=None, - checkpoint_format=LLambaHuggingfaceCheckpointFormat, + checkpoint_format=AprielHybridSSMCheckpointFormat, # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, @@ -581,25 +595,35 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests hybrid Mamba 2. "llama", - "hybrid_mamba2", - model_type="hybrid_ssm", + "hybrid_mamba_2", updates={ - ("model", "base_model", "ssm"): { - "type": "mamba_2", - "d_inner": 512, - "state_size": 8, - "dt_rank": 16, - "d_xb": 256, - "add_linear_biases": False, + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "t": copy.deepcopy(_llama_block), + "m2": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "mamba_2", + "dt_layer": {"bias": {"enabled": True}}, + "d_inner": 512, + "state_size": 8, + "dt_rank": 16, + "d_xb": 256, + "add_linear_biases": False, + }, + }, + }, + "num_blocks": 2, + "pattern": ["t", "m2"], }, - ("model", "base_model", "hybrid_block_layout"): "['t','m2']", }, megatron_args=None, - checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + checkpoint_format=AprielHybridSSMCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, @@ -616,26 +640,35 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests hybrid discrete Mamba 2. "llama", - "hybrid_discrete_mamba2", - model_type="hybrid_ssm", + "hybrid_discrete_mamba_2", updates={ - ("model", "base_model", "ssm"): { - "type": "discrete_mamba_2", - "d_inner": 512, - "state_size": 8, - "n_qk_heads": 8, - "n_v_heads": 16, - "chunk_size": 32, - "add_linear_biases": False, + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "t": copy.deepcopy(_llama_block), + "m2d": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "discrete_mamba_2", + "d_inner": 512, + "state_size": 8, + "n_qk_heads": 8, + "n_v_heads": 16, + "chunk_size": 32, + "add_linear_biases": False, + }, + }, + }, + "num_blocks": 2, + "pattern": ["t", "m2d"], }, - ("model", "base_model", "hybrid_block_layout"): "['t','m2d']", }, megatron_args=None, - checkpoint_format=AprielSSMHHybridHuggingfaceCheckpointFormat, + checkpoint_format=AprielHybridSSMCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, # TODO: Implement diff --git a/tests/utils/save_load_configs.py b/tests/utils/save_load_configs.py index f5a15020e..3e7cbf10f 100644 --- a/tests/utils/save_load_configs.py +++ b/tests/utils/save_load_configs.py @@ -18,13 +18,32 @@ class DistributedSaveLoadConfig: num_gpus: int = 2 def resolve(self, base_path: pathlib.Path, model_testing_config: ModelTestingConfig) -> typing.Self: + if model_testing_config.checkpoint_format is None: + format = { + "distributed": do_get_convert_path( + DistributedCheckpointFormat.name, FastLLMCheckpointFormat.name, base_path=pathlib.Path() + ), + "fast_llm": do_get_convert_path( + FastLLMCheckpointFormat.name, DistributedCheckpointFormat.name, base_path=pathlib.Path() + ), + } + else: + format = { + "checkpoint_format": model_testing_config.checkpoint_format.name, + "distributed": do_get_convert_path( + DistributedCheckpointFormat.name, + model_testing_config.checkpoint_format.name, + base_path=pathlib.Path(), + ), + "fast_llm": do_get_convert_path( + FastLLMCheckpointFormat.name, model_testing_config.checkpoint_format.name, base_path=pathlib.Path() + ), + } return dataclasses.replace( self, - load_path=base_path - / str(self.load_path).format(checkpoint_format=model_testing_config.checkpoint_format.name), - load_format=self.load_format.format(checkpoint_format=model_testing_config.checkpoint_format.name), - save_path=base_path - / str(self.save_path).format(checkpoint_format=model_testing_config.checkpoint_format.name), + load_path=base_path / str(self.load_path).format(**format), + load_format=self.load_format.format(**format), + save_path=base_path / str(self.save_path).format(**format), ) @property @@ -58,11 +77,11 @@ def get_convert_path(run_test_script_base_path): for pretrained_format, pretrained_path in ( ( DistributedCheckpointFormat.name, - do_get_convert_path(DistributedCheckpointFormat.name, "{checkpoint_format}", base_path=pathlib.Path()), + pathlib.Path("{distributed}"), ), ( FastLLMCheckpointFormat.name, - do_get_convert_path(FastLLMCheckpointFormat.name, "{checkpoint_format}", base_path=pathlib.Path()), + pathlib.Path("{fast_llm}"), ), ( "{checkpoint_format}",