Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions fast_llm/layers/language_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,16 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig):
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
init_method_max_embed: float | None = Field(
default=None,
desc="Max value for clamping initialized weights of the vocabulary embedding and output (logits).",
hint=FieldHint.feature,
)
init_method_min_embed: float | None = Field(
default=None,
desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).",
hint=FieldHint.feature,
)
cross_entropy_impl: CrossEntropyImpl = Field(
default=CrossEntropyImpl.auto,
desc="Implementation for the cross-entropy computation.",
Expand Down Expand Up @@ -169,4 +179,10 @@ def _validate(self):
self.transformer.init_method_std = self.transformer.hidden_size**-0.5
if self.init_method_std_embed is None:
self.init_method_std_embed = self.transformer.init_method_std
if self.init_method_max_embed is None:
self.init_method_max_embed = self.transformer.init_method_max
if self.init_method_min_embed is None:
self.init_method_min_embed = self.transformer.init_method_min
if self.init_method_max_embed is not None and self.init_method_min_embed is not None:
Assert.leq(self.init_method_min_embed, self.init_method_max_embed)
super()._validate()
12 changes: 10 additions & 2 deletions fast_llm/layers/language_model/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,20 @@ def __init__(

self.word_embeddings_weight = ParameterMeta.from_dims(
(vocab_dim, hidden_dim),
init_method=init_normal_(std=config.init_method_std_embed),
init_method=init_normal_(
std=config.init_method_std_embed,
min_val=config.init_method_min_embed,
max_val=config.init_method_max_embed,
),
)
if self._use_absolute_position_embeddings:
self.position_embeddings_weight = ParameterMeta.from_dims(
(tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim),
init_method=init_normal_(std=config.init_method_std_embed),
init_method=init_normal_(
std=config.init_method_std_embed,
min_val=config.init_method_min_embed,
max_val=config.init_method_max_embed,
),
allow_sequence_tensor_parallel=not config.parallel_embeddings,
)

Expand Down
6 changes: 5 additions & 1 deletion fast_llm/layers/language_model/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def __init__(
)
self.output_weights = ParameterMeta.from_dims(
(vocab_dim, hidden_dim),
init_method=init_normal_(std=config.init_method_std_embed),
init_method=init_normal_(
std=config.init_method_std_embed,
min_val=config.init_method_min_embed,
max_val=config.init_method_max_embed,
),
)

self._cross_entropy_impl = config.cross_entropy_impl
Expand Down
12 changes: 10 additions & 2 deletions fast_llm/layers/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,16 @@ def __init__(

self._triton_rotary = self._config.triton_rotary

init_method_qkv = init_normal_(std=self._config.init_method_std_qkv)
init_method_std_attn_proj = init_normal_(std=self._config.init_method_std_attn_proj)
init_method_qkv = init_normal_(
std=self._config.init_method_std_qkv,
min_val=self._config.init_method_min_qkv,
max_val=self._config.init_method_max_qkv,
)
init_method_std_attn_proj = init_normal_(
std=self._config.init_method_std_attn_proj,
min_val=self._config.init_method_min_attn_proj,
max_val=self._config.init_method_max_attn_proj,
)

self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size
self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size
Expand Down
78 changes: 78 additions & 0 deletions fast_llm/layers/transformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,30 +271,80 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig):
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
init_method_max: float | None = Field(
default=None,
desc="Max value for clamping initialized weights. Default: float('inf')",
hint=FieldHint.optional,
)
init_method_min: float | None = Field(
default=None,
desc="Min value for clamping initialized weights. Default: -float('inf')",
hint=FieldHint.optional,
)
init_method_std_qkv: float = Field(
default=None,
desc="Scale for the query, key and value weight initialization. Default: init_method_std",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
init_method_max_qkv: float | None = Field(
default=None,
desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')",
hint=FieldHint.optional,
)
init_method_min_qkv: float | None = Field(
default=None,
desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')",
hint=FieldHint.optional,
)
init_method_std_attn_proj: float = Field(
default=None,
desc="Scale for the attention projection weight initialization. Default: init_method_std",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
init_method_max_attn_proj: float | None = Field(
default=None,
desc="Max value for clamping initialized weights for attention projection. Default: float('inf')",
hint=FieldHint.optional,
)
init_method_min_attn_proj: float | None = Field(
default=None,
desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')",
hint=FieldHint.optional,
)
init_method_std_mlp_1: float = Field(
default=None,
desc="Scale for the MLP first layer weight initialization. Default: init_method_std",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
init_method_max_mlp_1: float | None = Field(
default=None,
desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')",
hint=FieldHint.optional,
)
init_method_min_mlp_1: float | None = Field(
default=None,
desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')",
hint=FieldHint.optional,
)
init_method_std_mlp_2: float = Field(
default=None,
desc="Scale for the MLP second layer weight initialization. Default: init_method_std",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
init_method_max_mlp_2: float | None = Field(
default=None,
desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')",
hint=FieldHint.optional,
)
init_method_min_mlp_2: float | None = Field(
default=None,
desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')",
hint=FieldHint.optional,
)
attention_dropout: float = Field(
default=0.0,
desc="Dropout applied to the attention intermediate states.",
Expand Down Expand Up @@ -413,6 +463,34 @@ def _validate(self):
self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5
if self.mlp_lr_scale is None or len(self.mlp_lr_scale) == 0:
self.mlp_lr_scale = [None]
if self.init_method_max_qkv is None:
self.init_method_max_qkv = self.init_method_max
if self.init_method_min_qkv is None:
self.init_method_min_qkv = self.init_method_min
if self.init_method_max_attn_proj is None:
self.init_method_max_attn_proj = self.init_method_max
if self.init_method_min_attn_proj is None:
self.init_method_min_attn_proj = self.init_method_min
if self.init_method_max_mlp_1 is None:
self.init_method_max_mlp_1 = self.init_method_max
if self.init_method_min_mlp_1 is None:
self.init_method_min_mlp_1 = self.init_method_min
if self.init_method_max_mlp_2 is None:
self.init_method_max_mlp_2 = self.init_method_max
if self.init_method_min_mlp_2 is None:
self.init_method_min_mlp_2 = self.init_method_min
if self.init_method_min is not None and self.init_method_max is not None:
Assert.leq(self.init_method_min, self.init_method_max)
if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None:
Assert.leq(self.init_method_min, self.init_method_max)
if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None:
Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv)
if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None:
Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj)
if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None:
Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1)
if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None:
Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2)
super()._validate()
if self.triton_rotary and not TritonConfig.TRITON_ENABLED:
warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.")
Expand Down
4 changes: 3 additions & 1 deletion fast_llm/layers/transformer/mixture_of_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s
tensor_space.get_tensor_dim(TransformerDimNames.hidden),
tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts),
bias=False,
weight_init_method=init_normal_(std=config.init_method_std),
weight_init_method=init_normal_(
std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max
),
lr_scale=config.router_lr_scale,
)
dropless_moe = config.dropless_moe
Expand Down
12 changes: 10 additions & 2 deletions fast_llm/layers/transformer/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s
super().__init__()
self._name = name

init_method_1 = init_normal_(std=config.init_method_std_mlp_1)
init_method_2 = init_normal_(std=config.init_method_std_mlp_2)
init_method_1 = init_normal_(
std=config.init_method_std_mlp_1,
min_val=config.init_method_min_mlp_1,
max_val=config.init_method_max_mlp_1,
)
init_method_2 = init_normal_(
std=config.init_method_std_mlp_2,
min_val=config.init_method_min_mlp_2,
max_val=config.init_method_max_mlp_2,
)

hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden)
self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp)
Expand Down
16 changes: 12 additions & 4 deletions fast_llm/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,15 +312,23 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator)
init_ones_ = init_fill_(1.0)


def init_normal_(mean=0.0, std=1.0):
def init_normal_(mean=0.0, std=1.0, min_val=None, max_val=None):
def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa
return tensor.normal_(mean, std, generator=generator)
tensor = tensor.normal_(mean, std, generator=generator)
if min_val is not None or max_val is not None:
return tensor.clamp_(min=min_val, max=max_val) # noqa
else:
return tensor

return init_


def init_uniform_(low=0.0, high=1.0):
def init_uniform_(low=0.0, high=1.0, min_val=None, max_val=None):
def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa
return tensor.uniform_(low, high, generator=generator) # noqa
tensor = tensor.uniform_(low, high, generator=generator)
if min_val is not None or max_val is not None:
return tensor.clamp_(min=min_val, max=max_val) # noqa
else:
return tensor

return init_