Skip to content
4 changes: 4 additions & 0 deletions fast_llm/models/gpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class Starcoder2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
class LlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "llama"

class Qwen2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "qwen2"


class MistralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "mistral"
Expand Down Expand Up @@ -98,6 +101,7 @@ class GPTModelConfig(FastLLMModelConfig):
AutoGPTHuggingfaceCheckpointFormat,
Starcoder2GPTHuggingfaceCheckpointFormat,
LlamaGPTHuggingfaceCheckpointFormat,
Qwen2GPTHuggingfaceCheckpointFormat,
MistralGPTHuggingfaceCheckpointFormat,
MixtralGPTHuggingfaceCheckpointFormat,
)
Expand Down
112 changes: 99 additions & 13 deletions fast_llm/models/gpt/conversion.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import abc
import dataclasses
import logging
import typing

import torch

from fast_llm.config import DEFAULT
from fast_llm.config import DEFAULT, MISSING
from fast_llm.engine.checkpoint.config import CheckpointFormat
from fast_llm.engine.checkpoint.external import (
AutoStateDictCheckpointHandler,
Expand All @@ -23,11 +24,12 @@
from fast_llm.functional.config import ActivationType
from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex
from fast_llm.layers.common.config import NormalizationType
from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType
from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig
from fast_llm.models.gpt.config import (
GPTArchitectureConfig,
GPTModelConfig,
LlamaGPTHuggingfaceCheckpointFormat,
Qwen2GPTHuggingfaceCheckpointFormat,
MistralGPTHuggingfaceCheckpointFormat,
MixtralGPTHuggingfaceCheckpointFormat,
Starcoder2GPTHuggingfaceCheckpointFormat,
Expand All @@ -39,6 +41,8 @@
if typing.TYPE_CHECKING:
pass

logger = logging.getLogger(__name__)


class QueryWeightConverter(WeightConverter):
# Hf uses the real format for rotary embeddings.
Expand Down Expand Up @@ -156,11 +160,14 @@ def _create_config_converters(cls) -> list[ParamConverter]:
def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
pass

def _create_weight_converters(self) -> list[WeightConverter]:

def _create_weight_converters(
self,
) -> list[WeightConverter]:
converters = []
num_layers = self._model.config.base_model.transformer.num_layers
norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm
linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases
transformer_config: TransformerConfig = self._model.config.base_model.transformer

# Embedding and output
if self._model.config.base_model.tie_word_embeddings:
Expand All @@ -180,17 +187,19 @@ def _create_weight_converters(self) -> list[WeightConverter]:
converters += self._get_weight_and_bias_converters(
f"layers.{i+1}.self_attn.query",
f"model.layers.{i}.self_attn.q_proj",
linear_bias,
transformer_config.add_attn_qkv_bias,
QueryWeightConverter,
)
converters += self._get_weight_and_bias_converters(
f"layers.{i+1}.self_attn.key_value",
(f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"),
linear_bias,
transformer_config.add_attn_qkv_bias,
KeyValueWeightConverter,
)
converters += self._get_weight_and_bias_converters(
f"layers.{i+1}.self_attn.dense", f"model.layers.{i}.self_attn.o_proj", linear_bias
f"layers.{i+1}.self_attn.dense",
f"model.layers.{i}.self_attn.o_proj",
transformer_config.add_attn_dense_bias,
)

# Norm
Expand Down Expand Up @@ -256,13 +265,16 @@ def _create_config_converters(cls) -> list[ParamConverter]:
]

def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases
transformer_config: TransformerConfig = self._model.config.base_model.transformer
return [
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", linear_bias
f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_mlp_bias
),
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", linear_bias, MLPLayer2Converter
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.c_proj",
transformer_config.add_mlp_bias,
MLPLayer2Converter,
),
]

Expand Down Expand Up @@ -352,18 +364,91 @@ def _create_config_converters(cls) -> list[ParamConverter]:
]

def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases
transformer_config: TransformerConfig = self._model.config.base_model.transformer
return [
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1",
(f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
transformer_config.add_mlp_bias,
SplitWeightConverter,
),
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.down_proj",
transformer_config.add_mlp_bias,
MLPLayer2Converter,
),
]


@dataclasses.dataclass
class IgnoreImportQwen2SlidingWindowParamsConverter(ParamConverter):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bigximik this is fine, but can you please add a todo here that says that this is a temporary hack until we can load these params from the config?

def __post_init__(self):
Assert.eq(len(self.fast_llm_names), 0)
Assert.eq(len(self.export_names), 0)
self.export_names = (("use_sliding_window",), ("sliding_window",), ("max_window_layers",))

def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
return (MISSING, MISSING, MISSING)

def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
# Default value for use_sliding_window in Qwen2 HF config is False
if export_values[0] != MISSING and export_values[0] == True:
logger.warning(
f"The configuration parameters `{self.export_names[0]}={export_values[0]}`,"
f" `{self.export_names[1]}={export_values[1]}`, `{self.export_names[2]}={export_values[2]}`"
f" are ignored during conversion."
f" If you intend to use them in Fast-LLM, make sure to set them explicitly in the model configuration."
)
return ()


class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler):
format: typing.ClassVar[type[CheckpointFormat]] = Qwen2GPTHuggingfaceCheckpointFormat

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
),
RenameParamConverter(
fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),)
),
ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv"
),
RopeScalingParamConverter(
fast_llm_names=(
("transformer", "rotary", "type"),
("transformer", "rotary", "scale_factor"),
("transformer", "rotary", "low_frequency_factor"),
("transformer", "rotary", "high_frequency_factor"),
("transformer", "rotary", "original_context_length"),
("transformer", "rotary", "attention_factor"),
("transformer", "rotary", "beta_fast"),
("transformer", "rotary", "beta_slow"),
),
export_names=(("rope_scaling",),),
),
IgnoreImportQwen2SlidingWindowParamsConverter(),
]

def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
transformer_config: TransformerConfig = self._model.config.base_model.transformer
return [
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1",
(f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
linear_bias,
transformer_config.add_mlp_bias,
SplitWeightConverter,
),
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.down_proj",
linear_bias,
transformer_config.add_mlp_bias,
MLPLayer2Converter,
),
]
Expand Down Expand Up @@ -439,6 +524,7 @@ class AutoGPTHuggingfaceCheckpointHandler(
handler_map = {
Starcoder2GPTHuggingfaceCheckpointFormat.name: Starcoder2HuggingfaceCheckpointHandler,
LlamaGPTHuggingfaceCheckpointFormat.name: LlamaHuggingfaceCheckpointHandler,
Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler,
MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler,
MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler,
}
20 changes: 20 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.models.gpt.config import (
LlamaGPTHuggingfaceCheckpointFormat,
Qwen2GPTHuggingfaceCheckpointFormat,
MistralGPTHuggingfaceCheckpointFormat,
MixtralGPTHuggingfaceCheckpointFormat,
Starcoder2GPTHuggingfaceCheckpointFormat,
Expand Down Expand Up @@ -155,6 +156,18 @@
]
CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"]

# Megatron does not support per sub layer biases
CONFIG_QWEN2_MEGATRON = None
CONFIG_QWEN2_FAST_LLM = CONFIG_SC2_FAST_LLM + [
"model.base_model.transformer.gated=True",
"model.base_model.transformer.activation_type=silu",
"model.base_model.transformer.add_linear_biases=only_attn_qkv",
"model.base_model.transformer.normalization.type=rms_norm",
"model.base_model.transformer.ffn_hidden_size=1024",
"model.base_model.tie_word_embeddings=False",
]
CONFIG_QWEN2_COMMON = CONFIG_QWEN2_FAST_LLM + ["model.distributed.training_dtype=bf16"]

# Yarn-style Rotary Embeddings
CONFIG_LLAMA_YARN_MEGATRON = None
CONFIG_LLAMA_YARN_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [
Expand Down Expand Up @@ -202,6 +215,13 @@
CONFIG_LLAMA3_COMMON,
LlamaGPTHuggingfaceCheckpointFormat,
),
"qwen2": (
"gpt",
CONFIG_QWEN2_FAST_LLM,
CONFIG_QWEN2_MEGATRON,
CONFIG_QWEN2_COMMON,
Qwen2GPTHuggingfaceCheckpointFormat,
),
"llama-yarn": (
"gpt",
CONFIG_LLAMA_YARN_FAST_LLM,
Expand Down