diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 86c556fc7..b88f45540 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -118,6 +118,16 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + init_method_max_embed: float | None = Field( + default=None, + desc="Max value for clamping initialized weights of the vocabulary embedding and output (logits).", + hint=FieldHint.feature, + ) + init_method_min_embed: float | None = Field( + default=None, + desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", + hint=FieldHint.feature, + ) cross_entropy_impl: CrossEntropyImpl = Field( default=CrossEntropyImpl.auto, desc="Implementation for the cross-entropy computation.", @@ -169,4 +179,10 @@ def _validate(self): self.transformer.init_method_std = self.transformer.hidden_size**-0.5 if self.init_method_std_embed is None: self.init_method_std_embed = self.transformer.init_method_std + if self.init_method_max_embed is None: + self.init_method_max_embed = self.transformer.init_method_max + if self.init_method_min_embed is None: + self.init_method_min_embed = self.transformer.init_method_min + if self.init_method_max_embed is not None and self.init_method_min_embed is not None: + Assert.leq(self.init_method_min_embed, self.init_method_max_embed) super()._validate() diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 989f71118..e3859b4f2 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -53,12 +53,20 @@ def __init__( self.word_embeddings_weight = ParameterMeta.from_dims( (vocab_dim, hidden_dim), - init_method=init_normal_(std=config.init_method_std_embed), + init_method=init_normal_( + std=config.init_method_std_embed, + min_val=config.init_method_min_embed, + max_val=config.init_method_max_embed, + ), ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( (tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim), - init_method=init_normal_(std=config.init_method_std_embed), + init_method=init_normal_( + std=config.init_method_std_embed, + min_val=config.init_method_min_embed, + max_val=config.init_method_max_embed, + ), allow_sequence_tensor_parallel=not config.parallel_embeddings, ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 611a63309..194d6f947 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -63,7 +63,11 @@ def __init__( ) self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), - init_method=init_normal_(std=config.init_method_std_embed), + init_method=init_normal_( + std=config.init_method_std_embed, + min_val=config.init_method_min_embed, + max_val=config.init_method_max_embed, + ), ) self._cross_entropy_impl = config.cross_entropy_impl diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index d4546d650..909780193 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -78,8 +78,16 @@ def __init__( self._triton_rotary = self._config.triton_rotary - init_method_qkv = init_normal_(std=self._config.init_method_std_qkv) - init_method_std_attn_proj = init_normal_(std=self._config.init_method_std_attn_proj) + init_method_qkv = init_normal_( + std=self._config.init_method_std_qkv, + min_val=self._config.init_method_min_qkv, + max_val=self._config.init_method_max_qkv, + ) + init_method_std_attn_proj = init_normal_( + std=self._config.init_method_std_attn_proj, + min_val=self._config.init_method_min_attn_proj, + max_val=self._config.init_method_max_attn_proj, + ) self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index e6c133f14..97fa112b3 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -271,30 +271,80 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): hint=FieldHint.optional, valid=check_field(Assert.geq, 0), ) + init_method_max: float | None = Field( + default=None, + desc="Max value for clamping initialized weights. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min: float | None = Field( + default=None, + desc="Min value for clamping initialized weights. Default: -float('inf')", + hint=FieldHint.optional, + ) init_method_std_qkv: float = Field( default=None, desc="Scale for the query, key and value weight initialization. Default: init_method_std", hint=FieldHint.optional, valid=check_field(Assert.geq, 0), ) + init_method_max_qkv: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_qkv: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')", + hint=FieldHint.optional, + ) init_method_std_attn_proj: float = Field( default=None, desc="Scale for the attention projection weight initialization. Default: init_method_std", hint=FieldHint.optional, valid=check_field(Assert.geq, 0), ) + init_method_max_attn_proj: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for attention projection. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_attn_proj: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", + hint=FieldHint.optional, + ) init_method_std_mlp_1: float = Field( default=None, desc="Scale for the MLP first layer weight initialization. Default: init_method_std", hint=FieldHint.optional, valid=check_field(Assert.geq, 0), ) + init_method_max_mlp_1: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_1: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", + hint=FieldHint.optional, + ) init_method_std_mlp_2: float = Field( default=None, desc="Scale for the MLP second layer weight initialization. Default: init_method_std", hint=FieldHint.optional, valid=check_field(Assert.geq, 0), ) + init_method_max_mlp_2: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_2: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", + hint=FieldHint.optional, + ) attention_dropout: float = Field( default=0.0, desc="Dropout applied to the attention intermediate states.", @@ -413,6 +463,34 @@ def _validate(self): self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5 if self.mlp_lr_scale is None or len(self.mlp_lr_scale) == 0: self.mlp_lr_scale = [None] + if self.init_method_max_qkv is None: + self.init_method_max_qkv = self.init_method_max + if self.init_method_min_qkv is None: + self.init_method_min_qkv = self.init_method_min + if self.init_method_max_attn_proj is None: + self.init_method_max_attn_proj = self.init_method_max + if self.init_method_min_attn_proj is None: + self.init_method_min_attn_proj = self.init_method_min + if self.init_method_max_mlp_1 is None: + self.init_method_max_mlp_1 = self.init_method_max + if self.init_method_min_mlp_1 is None: + self.init_method_min_mlp_1 = self.init_method_min + if self.init_method_max_mlp_2 is None: + self.init_method_max_mlp_2 = self.init_method_max + if self.init_method_min_mlp_2 is None: + self.init_method_min_mlp_2 = self.init_method_min + if self.init_method_min is not None and self.init_method_max is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) + if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: + Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) + if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: + Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) + if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: + Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) super()._validate() if self.triton_rotary and not TritonConfig.TRITON_ENABLED: warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 7c15731e7..8b06e0df8 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -62,7 +62,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s tensor_space.get_tensor_dim(TransformerDimNames.hidden), tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), bias=False, - weight_init_method=init_normal_(std=config.init_method_std), + weight_init_method=init_normal_( + std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max + ), lr_scale=config.router_lr_scale, ) dropless_moe = config.dropless_moe diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 7df029327..6e2f9381c 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -17,8 +17,16 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s super().__init__() self._name = name - init_method_1 = init_normal_(std=config.init_method_std_mlp_1) - init_method_2 = init_normal_(std=config.init_method_std_mlp_2) + init_method_1 = init_normal_( + std=config.init_method_std_mlp_1, + min_val=config.init_method_min_mlp_1, + max_val=config.init_method_max_mlp_1, + ) + init_method_2 = init_normal_( + std=config.init_method_std_mlp_2, + min_val=config.init_method_min_mlp_2, + max_val=config.init_method_max_mlp_2, + ) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 5787994f4..f4217f711 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -312,15 +312,23 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) init_ones_ = init_fill_(1.0) -def init_normal_(mean=0.0, std=1.0): +def init_normal_(mean=0.0, std=1.0, min_val=None, max_val=None): def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.normal_(mean, std, generator=generator) + tensor = tensor.normal_(mean, std, generator=generator) + if min_val is not None or max_val is not None: + return tensor.clamp_(min=min_val, max=max_val) # noqa + else: + return tensor return init_ -def init_uniform_(low=0.0, high=1.0): +def init_uniform_(low=0.0, high=1.0, min_val=None, max_val=None): def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.uniform_(low, high, generator=generator) # noqa + tensor = tensor.uniform_(low, high, generator=generator) + if min_val is not None or max_val is not None: + return tensor.clamp_(min=min_val, max=max_val) # noqa + else: + return tensor return init_