diff --git a/fast_llm/config.py b/fast_llm/config.py index 7e55c03bd..d8ae570c5 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -9,7 +9,7 @@ import yaml -from fast_llm.utils import Assert, Tag, get_type_name, header, log +from fast_llm.utils import Assert, Tag, get_type_name, header, log, pop_nested_dict_value, set_nested_dict_value logger = logging.getLogger(__name__) @@ -663,15 +663,7 @@ def from_dict( if isinstance(update, Config): update = update._to_dict(format_=_ConfigDictFormat.tuple) for keys, value in update.items(): - if isinstance(keys, str): - default[keys] = value - else: - dict_to_update = default - for key in keys[:-1]: - if key not in dict_to_update: - dict_to_update[key] = {} - dict_to_update = dict_to_update[key] - dict_to_update[keys[-1]] = value + set_nested_dict_value(default, keys, value) return cls._from_dict(default, strict) @@ -802,12 +794,21 @@ def _from_dict_dict(cls, value, type_, strict: bool): return {key: cls._from_dict_nested(value_, args[1], strict) for key, value_ in value.items()} @classmethod - def _handle_renamed_field(cls, default: dict[str, typing.Any], old_name: str, new_name: str): + def _handle_renamed_field( + cls, + default: dict[str, typing.Any], + old_name: str | tuple[str, ...], + new_name: str | tuple[str, ...], + fn: typing.Callable | None = None, + ): if old_name in default: warnings.warn(f"Field `{old_name}` is deprecated in class {get_type_name(cls)}, use `{new_name}` instead.") - default[new_name] = default.pop(old_name) + value = pop_nested_dict_value(default, old_name) + if fn is not None: + value = fn(value) + set_nested_dict_value(default, new_name, value) - def compare(self, other: "Config", log_fn: typing.Union[BaseException, typing.Callable] = ValueError): + def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typing.Callable] = ValueError): # TODO: Check classes? self_dict = self._to_dict(format_=_ConfigDictFormat.tuple, serializable=True) other_dict = other._to_dict(format_=_ConfigDictFormat.tuple, serializable=True) @@ -824,7 +825,7 @@ def compare(self, other: "Config", log_fn: typing.Union[BaseException, typing.Ca log( f"Config diff:\n " + "\n ".join( - f"{''.join(key)}`: `{self_value}` != `{other_value}`" + f"{'.'.join(key)}`: `{self_value}` != `{other_value}`" for key, (self_value, other_value) in diff.items() ), log_fn=log_fn, diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 6a9ba863f..8d2d91128 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -24,7 +24,7 @@ def get_architecture(self): def compare_architecture( self, model_config: "BaseModelArchitectureConfig", - log_fn: typing.Union[BaseException, typing.Callable] = ValueError, + log_fn: typing.Union[type[BaseException], typing.Callable] = ValueError, ): return self.get_architecture().compare(model_config.get_architecture(), log_fn) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 538002a52..733b38332 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -19,7 +19,7 @@ from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_nested_dict_value, set_nested_dict_value logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ @dataclasses.dataclass class ParamConverter: fast_llm_name: tuple[str, ...] | None - export_name: str | None + export_name: tuple[str, ...] | str | None def export_param(self, fast_llm_value): return fast_llm_value @@ -203,7 +203,7 @@ def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing else cls._get_fast_llm_attribute(config, converter.fast_llm_name) # Noqa ) if converter.export_name is not None: - exported_config[converter.export_name] = value + set_nested_dict_value(exported_config, converter.export_name, value) return exported_config # Noqa @@ -213,11 +213,11 @@ def _import_config( ) -> BaseModelArchitectureConfig: # noqa kwargs = {} for converter in cls._get_config_converters(): - value = converter.import_param( - None - if converter.export_name is None or converter.export_name not in config - else config[converter.export_name] - ) + try: + value = None if converter.export_name is None else get_nested_dict_value(config, converter.export_name) + except KeyError: + value = None + value = converter.import_param(value) if converter.fast_llm_name is not None: kwargs[converter.fast_llm_name] = value diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 1739901a4..07dad39fc 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -248,7 +248,7 @@ def is_main_rank(): return DistributedConfig.default_rank == _MAIN_RANK -def log_main_rank(*message, log_fn: typing.Union[BaseException, typing.Callable] = logger.info, join: str = ", "): +def log_main_rank(*message, log_fn: typing.Union[type[BaseException], typing.Callable] = logger.info, join: str = ", "): if is_main_rank(): log(*message, log_fn=log_fn, join=join) diff --git a/fast_llm/functional/rotary.py b/fast_llm/functional/rotary.py index c5844aa06..6af87f99c 100644 --- a/fast_llm/functional/rotary.py +++ b/fast_llm/functional/rotary.py @@ -1,5 +1,3 @@ -import math - import torch from fast_llm.utils import div @@ -13,31 +11,6 @@ def convert_rotary_real_to_complex(tensor: torch.Tensor, kv_channels: int, dim: return tensor.unflatten(dim, (-1, 2, div(kv_channels, 2))).movedim(dim + 1, dim + 2).flatten(dim, dim + 2) -def get_rotary_frequencies( - sequence_length, - kv_channels, - scale=-math.log(10000), - *, - complex_format: bool = True, - device="cuda", -): - # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) - # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, - # `a = theta ** - (2 * (channel // 2) / kv_channels)`, - # where n is the position in the sequence. - # We preform the calculation in high precision because it matters for rotary embeddings. - angles = torch.outer( - torch.arange(sequence_length, device=device, dtype=torch.float64), - torch.exp(scale * torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64)), - ) - frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) - if not complex_format: - frequencies = convert_rotary_complex_to_real( - torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 - ).contiguous() - return frequencies - - def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor) -> torch.Tensor: """ Apply rotary embeddings to a tensor: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 9e54ea35f..ae0a96e5f 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -60,7 +60,7 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): def _validate(self): if self.use_position_embeddings is None: - self.use_position_embeddings = not self.transformer.use_rotary_embeddings + self.use_position_embeddings = not self.transformer.rotary.enabled super()._validate() def setup_tensor_space(self, tensor_space: TensorSpace): diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 05e36054b..42ed61db7 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -76,8 +76,6 @@ def __init__( self._debug_transformer = self._config.debug_transformer self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) - self._triton_rotary = self._config.triton_rotary - init_method_qkv = init_normal_( std=self._config.init_method_std_qkv, min_val=self._config.init_method_min_qkv, @@ -300,7 +298,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict): key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._config.use_rotary_position_embeddings: + if self._config.rotary.enabled: if self._debug_transformer: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( @@ -309,7 +307,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict): self._KV_DIMS, kwargs, ) - rotary_fn = triton_rotary_autograd_ if self._triton_rotary else apply_rotary_embeddings + rotary_fn = triton_rotary_autograd_ if self._config.rotary.triton else apply_rotary_embeddings query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q]) key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k]) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 97fa112b3..3e688d1c3 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -1,6 +1,7 @@ import enum import logging import math +import typing import warnings from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none @@ -75,6 +76,72 @@ class TransformerLossNames: router_z_loss = "router_z_loss" +class RotaryEmbeddingType(str, enum.Enum): + none = "none" + default = "default" + llama3 = "llama3" + + +@config_class() +class RotaryArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + type: RotaryEmbeddingType = Field( + default=RotaryEmbeddingType.none, + desc="The type of rotary embedding to use. Choices: none, default, llama3.", + hint=FieldHint.feature, + ) + theta: float = Field( + default=10000, + desc="Scale for the rotary positional embeddings", + hint=FieldHint.feature, + ) + # TODO: Make a backup implementation that doesn't affect the layout. + triton: bool = Field( + default=True, + desc="Enable the triton implementation of the rotary embeddings. Affects the model layout.", + hint=FieldHint.deprecated, + ) + # TODO: These are not really architecture parameters, but we want to import them from huggingface. + scale_factor: float = Field(default=8.0, desc="Scaling factor for llama3-type scaling.", hint=FieldHint.feature) + low_frequency_factor: float = Field( + default=1.0, desc="Low frequency factor for llama3-type scaling.", hint=FieldHint.feature + ) + high_frequency_factor: float = Field( + default=4.0, desc="High frequency factor for llama3-type scaling.", hint=FieldHint.feature + ) + original_context_length: int = Field( + default=8192, desc="Original context length for llama3-type scaling.", hint=FieldHint.feature + ) + + @property + def enabled(self): + return self.type != RotaryEmbeddingType.none + + @property + def complex_format(self): + # TODO: Make a backup implementation that doesn't affect the layout. + return self.enabled and not self.triton + + def _validate(self): + # These happen during conversion. + if self.scale_factor is None: + self.scale_factor = 8.0 + if self.low_frequency_factor is None: + self.low_frequency_factor = 1.0 + if self.high_frequency_factor is None: + self.high_frequency_factor = 4.0 + if self.original_context_length is None: + self.original_context_length = 8192 + super()._validate() + if self.triton and not TritonConfig.TRITON_ENABLED: + warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") + + +@config_class() +class RotaryConfig(RotaryArchitectureConfig, BaseModelConfig): + pass + + @config_class() class TransformerArchitectureConfig(BaseModelArchitectureConfig): _abstract = False @@ -119,12 +186,9 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - use_rotary_embeddings: bool = Field( - default=False, desc="Enable rotary positional embeddings.", hint=FieldHint.feature - ) - rotary_embedding_scale: float = Field( - default=-math.log(10000), - desc="Scale for the rotary positional embeddings. Default: -math.log(10000) = -9.210", + rotary: RotaryArchitectureConfig = Field( + default_factory=RotaryArchitectureConfig, + desc="Configuration for the rotary positional embeddings.", hint=FieldHint.feature, ) gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.feature) @@ -133,11 +197,6 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig): desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", hint=FieldHint.core, ) - triton_rotary: bool = Field( - default=True, - desc="Enable the triton implementation of the rotary embeddings. Affects the model layout.", - hint=FieldHint.deprecated, - ) num_experts: int = Field( default=1, desc="Number of MLP experts in a Mixture of Expert (MoE) model", @@ -185,6 +244,24 @@ def _validate(self): Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) Assert.multiple(self.num_attention_heads, self.head_groups) + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ): + # TODO v0.x: Remove backward compatibility. + cls._handle_renamed_field( + default, + "use_rotary_embeddings", + ("rotary", "type"), + lambda x: RotaryEmbeddingType.default if x else RotaryEmbeddingType.none, + ) + cls._handle_renamed_field(default, "rotary_embedding_scale", ("rotary", "theta"), lambda x: math.exp(-x)) + cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) + return super()._from_dict(default, strict, flat) + def setup_tensor_space(self, tensor_space: TensorSpace): tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -245,24 +322,11 @@ def setup_tensor_space(self, tensor_space: TensorSpace): ) ) - @property - def complex_rotary_embeddings(self): - return self.use_rotary_position_embeddings and not self.triton_rotary - - @property - def rotary_position_embedding_scale(self): - # TODO: Set through rotary theta instead. - return self.rotary_embedding_scale if self.use_rotary_position_embeddings else None - - @property - def use_rotary_position_embeddings(self): - # TODO: Set through rotary theta instead. - return self.use_rotary_embeddings - @config_class() class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig) + rotary: RotaryConfig = FieldUpdate(default_factory=RotaryConfig) # Default: hidden_size**-0.5 # TODO: Allow custom initialization (InitializationConfig?) init_method_std: float = Field( @@ -492,8 +556,6 @@ def _validate(self): if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) super()._validate() - if self.triton_rotary and not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") Assert.geq(self.attention_dropout, 0) Assert.geq(self.hidden_dropout, 0) Assert.incl(len(self.mlp_lr_scale), (1, self.num_experts)) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index da093bb43..b8c838190 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -1,15 +1,70 @@ import logging +import math import torch from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.functional.rotary import get_rotary_frequencies -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.functional.rotary import convert_rotary_complex_to_real +from fast_llm.layers.transformer.config import ( + RotaryConfig, + RotaryEmbeddingType, + TransformerConfig, + TransformerDimNames, + TransformerKwargs, +) from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) +def apply_llama3_scaling(config: RotaryConfig, frequencies: torch.Tensor) -> torch.Tensor: + """ + Llama3 scaling: https://github.com/meta-llama/llama-models/blob/baf7b01b6e62bc7126c7b558d2b67d4533142680/models/llama3/reference_impl/model.py#L45-L67 + """ + low_frequency_wavelength = config.original_context_length / config.low_frequency_factor + high_frequency_wavelength = config.original_context_length / config.high_frequency_factor + new_frequencies = [] + for frequency in frequencies: + wavelength = 2 * math.pi / frequency + if wavelength < high_frequency_wavelength: + new_frequencies.append(frequency) + elif wavelength > low_frequency_wavelength: + new_frequencies.append(frequency / config.scale_factor) + else: + assert low_frequency_wavelength != high_frequency_wavelength + smooth = (config.original_context_length / wavelength - config.low_frequency_factor) / ( + config.high_frequency_factor - config.low_frequency_factor + ) + new_frequencies.append((1 - smooth) * frequency / config.scale_factor + smooth * frequency) + return torch.tensor(new_frequencies, dtype=frequencies.dtype, device=frequencies.device) + + +def get_rotary_frequencies( + config: RotaryConfig, + sequence_length, + kv_channels, + *, + device="cuda", +): + # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) + # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, + # `a = theta ** - (2 * (channel // 2) / kv_channels)`, + # where n is the position in the sequence. + # We preform the calculation in high precision because it matters for rotary embeddings. + positions = torch.arange(sequence_length, device=device, dtype=torch.float64) + frequencies = config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + # Apply scaling + if config.type == RotaryEmbeddingType.llama3: + frequencies = apply_llama3_scaling(config, frequencies) + angles = torch.outer(positions, frequencies) + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not config.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + ).contiguous() + return frequencies + + class RotaryEmbeddingPreprocessor: _scalar_dim: TensorDim _kv_channels_dim: TensorDim @@ -20,11 +75,11 @@ class RotaryEmbeddingPreprocessor: def __init__( self, - config: TransformerConfig, + config: RotaryConfig, tensor_space: TensorSpace, ): self._config = config - assert self._config.use_rotary_position_embeddings + assert self._config.enabled self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) @@ -36,10 +91,9 @@ def create_tensors(self, sequence_length: int): self._tensor_cache_max_sequence_length = sequence_length self._rotary_embedding_frequencies = get_rotary_frequencies( + self._config, sequence_length, - self._config.kv_channels, - self._config.rotary_position_embedding_scale, - complex_format=self._config.complex_rotary_embeddings, + self._kv_channels_dim.global_size, device=self._tensor_space.distributed.device, ) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index eea98b8b6..578892ae5 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -1,5 +1,4 @@ import abc -import math import typing import torch @@ -21,7 +20,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.common.config import NormalizationType -from fast_llm.layers.transformer.config import RoutingType +from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType from fast_llm.models.gpt.config import ( GPTArchitectureConfig, GPTModelConfig, @@ -45,7 +44,7 @@ def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight - if self._config.transformer.complex_rotary_embeddings: + if self._config.transformer.rotary.complex_format: query = convert_rotary_complex_to_real(query[:], self._config.transformer.kv_channels, 0) return (query,) @@ -53,7 +52,7 @@ def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight - if self._config.transformer.complex_rotary_embeddings: + if self._config.transformer.rotary.complex_format: query = convert_rotary_real_to_complex(query[:], self._config.transformer.kv_channels, 0) return (query,) @@ -67,7 +66,7 @@ def export_weight( ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (key_value,) = weight key, value = key_value[:].chunk(2) - if self._config.transformer.complex_rotary_embeddings: + if self._config.transformer.rotary.complex_format: key = convert_rotary_complex_to_real(key, self._config.transformer.kv_channels, 0) return key, value @@ -75,7 +74,7 @@ def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: key, value = weight - if self._config.transformer.complex_rotary_embeddings: + if self._config.transformer.rotary.complex_format: key = convert_rotary_real_to_complex(key[:], self._config.transformer.kv_channels, 0) key_value = torch.cat([key[:], value[:]]) return (key_value,) @@ -114,10 +113,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str): def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantImportParamConverter(("use_position_embeddings",), None, False), - ConstantImportParamConverter(("transformer", "use_rotary_embeddings"), None, True), - MappedConfigParamConverter( - ("transformer", "rotary_embedding_scale"), "rope_theta", lambda x: -math.log(x), lambda x: math.exp(-x) - ), + ParamConverter(("transformer", "rotary", "theta"), "rope_theta"), MappedConfigParamConverter( ("transformer", "activation_type"), "hidden_act", @@ -219,6 +215,7 @@ class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler) def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(None, "architectures", ["Starcoder2ForCausalLM"]), + ConstantImportParamConverter(("transformer", "rotary", "type"), None, RotaryEmbeddingType.default), ConstantImportParamConverter(("transformer", "normalization", "type"), None, NormalizationType.layer_norm), ParamConverter(("transformer", "normalization", "epsilon"), "norm_epsilon"), ConstantImportParamConverter(("transformer", "gated"), None, False), @@ -248,6 +245,28 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] +def export_rotary_scaling_type(fast_llm_value: RotaryEmbeddingType): + match fast_llm_value: + case RotaryEmbeddingType.default: + return "default" + case RotaryEmbeddingType.llama3: + return "llama3" + case _: + raise ValueError(f"Unsupported rotary scaling type: {fast_llm_value}") + + +def import_rotary_scaling_type(export_value): + if export_value is None: + return RotaryEmbeddingType.default + match export_value: + case "default": + return RotaryEmbeddingType.default + case "llama3": + return RotaryEmbeddingType.llama3 + case _: + raise ValueError(f"Unsupported rotary scaling type: {export_value}") + + class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlamaGPTHuggingfaceCheckpointFormat @@ -258,7 +277,28 @@ def _create_config_converters(cls) -> list[ParamConverter]: # TODO: Llama supports biases ConstantExportParamConverter(None, "attention_bias", False), ConstantExportParamConverter(None, "mlp_bias", False), - ConstantExportParamConverter(None, "rope_scaling", None), + MappedConfigParamConverter( + ("transformer", "rotary", "type"), + ("rope_scaling", "rope_type"), + import_rotary_scaling_type, + export_rotary_scaling_type, + ), + ParamConverter( + ("transformer", "rotary", "scale_factor"), + ("rope_scaling", "factor"), + ), + ParamConverter( + ("transformer", "rotary", "low_frequency_factor"), + ("rope_scaling", "low_freq_factor"), + ), + ParamConverter( + ("transformer", "rotary", "high_frequency_factor"), + ("rope_scaling", "high_freq_factor"), + ), + ParamConverter( + ("transformer", "rotary", "original_context_length"), + ("rope_scaling", "original_max_position_embeddings"), + ), ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str): @@ -286,6 +326,7 @@ class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(None, "architectures", ["MistralForCausalLM"]), + ConstantImportParamConverter(("transformer", "rotary", "type"), None, RotaryEmbeddingType.default), IgnoreImportParamConverter(None, "sliding_window", None), ] @@ -310,6 +351,7 @@ class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(None, "architectures", ["MixtralForCausalLM"]), + ConstantImportParamConverter(("transformer", "rotary", "type"), None, RotaryEmbeddingType.default), ConstantImportParamConverter(("transformer", "expert_routing_type"), None, RoutingType.topk), ParamConverter(("transformer", "num_experts"), "num_local_experts"), ParamConverter(("transformer", "num_experts_per_token"), "num_experts_per_tok"), diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index 0ac400b94..a5ef15a64 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -90,7 +90,7 @@ def _init_attention_megatron( else: raise NotImplementedError(meta.tensor_name) - if config.use_rotary_position_embeddings and config.complex_rotary_embeddings: + if config.rotary.enabled and config.rotary.complex_format: from fast_llm.functional.rotary import convert_rotary_real_to_complex # Megatron uses (2, kv_channels/2) for the complex split; we use (kv_channels/2, 2). diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index baf1f5c4d..90f2f883a 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -55,9 +55,9 @@ def __init__( param.init_parameter = get_init_megatron(param, self._config.transformer) # Noqa if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor = PositionEmbeddingPreprocessor(self._config, self._tensor_space) - if self._config.transformer.use_rotary_position_embeddings: + if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( - self._config.transformer, self._tensor_space + self._config.transformer.rotary, self._tensor_space ) if not self._use_flash_attention: self._backup_attention_preprocessor = BackupAttentionPreprocessor( @@ -172,7 +172,7 @@ def preprocess_meta(self, input_: BatchConfig | torch.Tensor, phase: PhaseType) ) if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess_meta(kwargs) - if self._config.transformer.use_rotary_position_embeddings: + if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor.preprocess_meta(kwargs) if not self._use_flash_attention: self._backup_attention_preprocessor.preprocess_meta(kwargs) @@ -211,7 +211,7 @@ def preprocess( if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.create_tensors(sequence_length) - if self._config.transformer.use_rotary_position_embeddings: + if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor.create_tensors(sequence_length) if not self._use_flash_attention: self._backup_attention_preprocessor.create_tensors(sequence_length) @@ -244,7 +244,7 @@ def preprocess( kwargs[LanguageModelKwargs.labels] = labels if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess(kwargs) - if self._config.transformer.use_rotary_position_embeddings: + if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor.preprocess(kwargs) if not self._use_flash_attention: self._backup_attention_preprocessor.preprocess(kwargs) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 5ae1c5d0d..2f8981e87 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -206,12 +206,13 @@ def __getitem__(self, key): return super().__getitem__(key)() -def log(*message, log_fn: typing.Union[BaseException, typing.Callable] = logger.info, join: str = ", "): +def log(*message, log_fn: typing.Union[type[BaseException], typing.Callable] = logger.info, join: str = ", "): message = join.join([str(m() if callable(m) else m) for m in message]) - if isinstance(log_fn, BaseException): - raise log_fn(message) + logged = log_fn(message) + if isinstance(logged, BaseException): + raise logged else: - return log_fn(message) + return logged def normalize_probabilities(p: list[float]) -> list[float]: @@ -222,3 +223,31 @@ def normalize_probabilities(p: list[float]) -> list[float]: p_sum = p.sum() Assert.gt(p_sum, 0) return (p / p_sum).tolist() + + +def set_nested_dict_value(d: dict, keys: str | tuple[str, ...], value): + if isinstance(keys, str): + d[keys] = value + else: + for key in keys[:-1]: + d = d.setdefault(key, {}) + assert isinstance(d, dict) + d[keys[-1]] = value + + +def get_nested_dict_value(d: dict, keys: str | tuple[str, ...]): + if isinstance(keys, str): + return d[keys] + else: + for key in keys: + d = d[key] + return d + + +def pop_nested_dict_value(d: dict, keys: str | tuple[str, ...]): + if isinstance(keys, str): + return d.pop(keys) + else: + for key in keys[:-1]: + d = d[key] + return d.pop(keys[-1]) diff --git a/tests/common.py b/tests/common.py index a944dd826..b96a22040 100644 --- a/tests/common.py +++ b/tests/common.py @@ -12,6 +12,7 @@ from fast_llm.data.gpt.memmap import GPTMemmapDataset from fast_llm.models.gpt.config import ( + LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, @@ -101,7 +102,7 @@ CONFIG_SC2_FAST_LLM = CONFIG_BASE_FAST_LLM + [ "model.base_model.transformer.head_groups=4", - "model.base_model.transformer.use_rotary_embeddings=True", + "model.base_model.transformer.rotary.type=default", ] CONFIG_SC2_MEGATRON = CONFIG_SC1_MEGATRON + [ "--num-query-groups=4", @@ -137,6 +138,17 @@ ] CONFIG_MIXTRAL_COMMON = CONFIG_MIXTRAL_FAST_LLM + ["model.distributed.training_dtype=bf16"] +CONFIG_LLAMA3_MEGATRON = None # Megatron does not support Llama3-style Rotary Embeddings +CONFIG_LLAMA3_FAST_LLM = CONFIG_SC2_FAST_LLM + [ + "model.base_model.transformer.gated=True", + "model.base_model.transformer.activation_type=silu", + "model.base_model.transformer.add_linear_biases=False", + "model.base_model.transformer.normalization.type=rms_norm", + "model.base_model.transformer.rotary.type=llama3", + "model.base_model.tie_word_embeddings=False", +] +CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"] + _CONFIGS = { "gpt2": ("gpt", CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_GPT2_COMMON, None), "sc1": ("gpt", HuggingfaceGPTModelForCausalLM, CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None), @@ -161,6 +173,13 @@ CONFIG_MIXTRAL_COMMON, MixtralGPTHuggingfaceCheckpointFormat, ), + "llama3": ( + "gpt", + CONFIG_LLAMA3_FAST_LLM, + CONFIG_LLAMA3_MEGATRON, + CONFIG_LLAMA3_COMMON, + LlamaGPTHuggingfaceCheckpointFormat, + ), } diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 955e8b9af..db14a813e 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -240,8 +240,8 @@ def _compare_configs(config_ref, config_test): @pytest.mark.depends(on=["test_converted_distributed"]) def test_load_pretrained_distributed_checkpoint(): - config = TEST_ARCHITECTURE_CONFIG_CLS.from_dict( - yaml.safe_load((_CKPT_PATH / ".." / ".." / "config.yaml").open("r")), strict=False + config = TEST_MODEL_CONFIG_CLS.from_dict( + yaml.safe_load((_CKPT_PATH / ".." / ".." / "config.yaml").open("r"))["model"], strict=False ) pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, @@ -250,7 +250,7 @@ def test_load_pretrained_distributed_checkpoint(): load_config=ModelConfigType.fast_llm, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) - _compare_configs(config, model._base_model_config) + _compare_configs(config.base_model, model._base_model_config) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._state_shard.device) )["state_shard"] diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py index cd79624d7..c23c321a0 100644 --- a/tests/test_triton_kernels.py +++ b/tests/test_triton_kernels.py @@ -6,7 +6,6 @@ apply_rotary_embeddings, convert_rotary_complex_to_real, convert_rotary_real_to_complex, - get_rotary_frequencies, ) from fast_llm.functional.triton.adam import triton_adam from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward @@ -23,6 +22,8 @@ from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill from fast_llm.functional.triton.rotary import triton_rotary_ from fast_llm.functional.triton.sparse_copy import get_sparse_map +from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType +from fast_llm.layers.transformer.preprocessing import get_rotary_frequencies from fast_llm.utils import Assert, rms_diff from tests.common import requires_cuda @@ -82,13 +83,26 @@ def test_triton_add(): def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): assert TritonConfig.TRITON_ENABLED x = torch.randn(batch_size, sequence_length, num_heads, kv_channels, dtype=torch.bfloat16, device="cuda") - y1 = apply_rotary_embeddings(x, get_rotary_frequencies(sequence_length, kv_channels, device="cuda")) - convert_rotary_complex_to_real(x, kv_channels, 3) + + y1 = apply_rotary_embeddings( + x, + get_rotary_frequencies( + RotaryConfig(type=RotaryEmbeddingType.default, triton=False), + sequence_length, + kv_channels, + device="cuda", + ), + ) y2 = convert_rotary_real_to_complex( triton_rotary_( convert_rotary_complex_to_real(x, kv_channels, 3), - get_rotary_frequencies(sequence_length, kv_channels, device="cuda", complex_format=False), + get_rotary_frequencies( + RotaryConfig(type=RotaryEmbeddingType.default, triton=True), + sequence_length, + kv_channels, + device="cuda", + ), ), kv_channels, 3,