From 9ddfb69115d6d84d960e76c36d4e65994eae76cc Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 25 Apr 2025 21:43:23 +0000 Subject: [PATCH 1/9] add per-layer lr-scale --- fast_llm/layers/common/config.py | 3 ++- fast_llm/layers/common/normalization.py | 5 +++++ fast_llm/layers/language_model/config.py | 13 +++++++++++++ fast_llm/layers/language_model/embedding.py | 2 ++ fast_llm/layers/language_model/head.py | 1 + fast_llm/layers/transformer/attention.py | 13 ++++++++----- fast_llm/layers/transformer/config.py | 6 ++++++ .../layers/transformer/mixture_of_experts.py | 9 ++++++--- fast_llm/layers/transformer/mlp.py | 15 ++++++++++----- fast_llm/layers/transformer/transformer.py | 5 +++-- fast_llm/models/gpt/model.py | 2 +- fast_llm/utils.py | 18 ++++++++++++++++++ 12 files changed, 75 insertions(+), 17 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 71c15c9b8..6e596751d 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -82,7 +82,7 @@ class NormalizationConfig(NormalizationArchitectureConfig, BaseModelConfig): valid=check_field(Assert.geq, 0), ) - def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": from fast_llm.layers.common.normalization import LayerNorm, RMSNorm from fast_llm.tensor import init_uniform_ @@ -91,6 +91,7 @@ def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": "eps": self.epsilon, "implementation": self.implementation, "zero_centered": self.zero_centered, + "lr_scale": lr_scale, } if self.initialization_range: mean = 0 if self.zero_centered else 1 diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 04123014e..984778f83 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -152,6 +152,7 @@ def __init__( weight_init_method=None, bias_init_method=init_zeros_, zero_centered: bool = False, + lr_scale: float | None = None, ): super().__init__() assert hidden_dim.parallel_dim is None @@ -190,12 +191,14 @@ def __init__( init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, + lr_scale=lr_scale, ) self.bias = ParameterMeta.from_dims( (hidden_dim,), init_method=bias_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, + lr_scale=lr_scale, ) self.normalized_shape = self.weight.shape @@ -230,6 +233,7 @@ def __init__( implementation: NormalizationImplementation = NormalizationImplementation.auto, weight_init_method=None, zero_centered: bool = False, + lr_scale: float | None = None, ): super().__init__() assert hidden_dim.parallel_dim is None @@ -263,6 +267,7 @@ def __init__( init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=True, + lr_scale=lr_scale, ) self.normalized_shape = self.weight.shape diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index b4b4e187c..c99ee4f6a 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -202,6 +202,19 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + embeddings_lr_scale: float | None = Field( + default=None, + desc="Learning rate scale for the word embeddings.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + output_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the output weights.", + doc="May be used to freeze the output weights by setting their scale to zero.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.transformer.validate() diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d9406ed1..e0386d8df 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -62,6 +62,7 @@ def __init__( min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.embeddings_lr_scale, ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( @@ -72,6 +73,7 @@ def __init__( max_val=config.init_method_max_embed, ), allow_sequence_tensor_parallel=not config.parallel_embeddings, + lr_scale=config.embeddings_lr_scale, ) # PEFT. diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c2974415d..1153fb2c2 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -102,6 +102,7 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.output_lr_scale, ) def forward( diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index c7ae55c5c..54fff2286 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -17,7 +17,7 @@ ) from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -84,7 +84,7 @@ def __init__( super().__init__() self._config = config self._tensor_space = tensor_space - Assert.in_range_incl(layer_index, 1, self._config.num_layers) + # Assert.in_range_incl(layer_index, 1, self._config.num_layers) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer @@ -110,6 +110,9 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) + # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, @@ -118,7 +121,7 @@ def __init__( weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, @@ -127,7 +130,7 @@ def __init__( weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) @@ -139,7 +142,7 @@ def __init__( weight_init_method=init_method_std_attn_proj, bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) # PEFT. diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index cf409e773..c13c2a093 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -636,6 +636,12 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): doc="May be used to freeze some experts by setting their scale to zero.", hint=FieldHint.feature, ) + per_layer_lr_scale: list[float] | None = Field( + default=None, + desc="Custom learning rate scale for each layer.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + ) router_lr_scale: float | None = Field( default=None, desc="Custom learning rate for the MoE router weight.", diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 85c6686f4..49778c63f 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -21,7 +21,7 @@ from fast_llm.layers.transformer.mlp import MLPBase from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta, init_normal_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." @@ -59,6 +59,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) + self.router = Linear( tensor_space.get_tensor_dim(TransformerDimNames.hidden), tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), @@ -66,7 +69,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max ), - lr_scale=config.router_lr_scale, + lr_scale=router_lr_scale, ) dropless_moe = config.dropless_moe if dropless_moe and tensor_space.distributed_config.sequence_tensor_parallel: diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 1c38705f9..c4d8afdc7 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -10,13 +10,14 @@ from fast_llm.layers.common.linear import LinearBase from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): super().__init__() self._name = name + self._layer_index = layer_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -38,6 +39,10 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale + lr_scale = get_lr_scale(lr_scale, layer_lr_scale) + # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, @@ -45,7 +50,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, - lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, + lr_scale=lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, @@ -55,7 +60,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, - lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, + lr_scale=lr_scale, ) # PEFT. @@ -64,7 +69,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): Assert.eq(config.num_experts, 1) super().__init__(config, tensor_space, name) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 92df18937..9e1e0bcfa 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -39,8 +39,9 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale=layer_lr_scale) + self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale=layer_lr_scale) self._create_mixer() diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a7ec58d67..873c8f80e 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -80,7 +80,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=self._config.transformer.num_layers, + layer_index=self._config.transformer.num_layers + i, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, diff --git a/fast_llm/utils.py b/fast_llm/utils.py index a8c5eac61..c524a315d 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -326,3 +326,21 @@ def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple def check_equal_nested(config_a, config_b): if errors := compare_nested(config_a, config_b): raise ValueError("\n".join(errors)) + + +def get_lr_scale( + lr_scale: float | None | tuple[float | None, ...], layer_lr_scale: float | None +) -> float | None | tuple[float | None, ...]: + """ + Combine module and layer lr_scale. + If one is None, return the other. + """ + if lr_scale is None: + return layer_lr_scale + if layer_lr_scale is None: + return lr_scale + if isinstance(lr_scale, float): + return lr_scale * layer_lr_scale + if isinstance(lr_scale, tuple): + return tuple(lrs * layer_lr_scale if lrs is not None else layer_lr_scale for lrs in lr_scale) + raise ValueError(f"Invalid lr_scale: {lr_scale} (type {type(lr_scale)})") From 77ad39f7730f314d01c3c8f5da14f1ac8aabf5f4 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 30 Apr 2025 20:21:08 +0000 Subject: [PATCH 2/9] add token-prediction loss coefficients --- fast_llm/layers/language_model/config.py | 10 ++++++++++ fast_llm/layers/language_model/head.py | 5 ++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index c99ee4f6a..c675361a2 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -215,6 +215,12 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): doc="May be used to freeze the output weights by setting their scale to zero.", hint=FieldHint.feature, ) + prediction_loss_coefficient: list[float] | None = Field( + default=None, + desc="Loss coefficient for each prediction head.", + doc="If not provided, all heads are equally weighted.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.transformer.validate() @@ -231,3 +237,7 @@ def _validate(self) -> None: if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") + if isinstance(self.prediction_loss_coefficient, list): + Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) + for coeff in self.prediction_loss_coefficient: + Assert.geq(coeff, 0) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 1153fb2c2..014a617cb 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -57,6 +57,9 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + self._loss_coefficient = ( + config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 + ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) self.final_norm = config.transformer.normalization.get_layer(hidden_dim) self._logits_scale_factor = config.logits_scale_factor @@ -133,7 +136,7 @@ def forward( else: if self.training: # Backward hook to compute the gradient of the loss - shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0) + shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, self._loss_coefficient) # MTP: Return shared_hidden to be used by the next head. return shared_hidden From 41d4da3491faa7ac7c1a6bd599be4fa41b97feeb Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 5 May 2025 21:50:05 +0000 Subject: [PATCH 3/9] disable freezing --- fast_llm/tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 849307563..611eb9f48 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -234,7 +234,9 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + # TODO: re-enable when fixed? + # self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + self.requires_grad = requires_grad # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) From 9c4f38f92c26c4a3a44ab67795f9dd3b58840245 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 7 May 2025 15:39:51 +0000 Subject: [PATCH 4/9] layer-lr scale for mlp as well --- fast_llm/layers/transformer/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 9e1e0bcfa..982381720 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -46,7 +46,7 @@ def __init__( self._create_mixer() self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp" + self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index ) # PEFT. From 6fe2b6d754734bc18dd04cb2d45e2d2257d6619d Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 12 May 2025 20:41:23 +0000 Subject: [PATCH 5/9] add check for length of per_layer_lr_scale --- fast_llm/layers/language_model/config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index c675361a2..93d1496d9 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -241,3 +241,9 @@ def _validate(self) -> None: Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) for coeff in self.prediction_loss_coefficient: Assert.geq(coeff, 0) + if self.transformer.per_layer_lr_scale is not None: + # -1 because the first prediction head's transformer layer is accounted for in num_layers + # +1 because the layer index starts at 1 + Assert.eq( + len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + self.prediction_heads - 1 + 1 + ) From 83baeefb74d9ee98c868a2e133b7c1f91a0c45f7 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 13 May 2025 20:59:19 +0000 Subject: [PATCH 6/9] re-enable freezing --- fast_llm/tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 611eb9f48..c82a3bf18 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -235,8 +235,8 @@ def __init__( self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) # TODO: re-enable when fixed? - # self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) - self.requires_grad = requires_grad + self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + # self.requires_grad = requires_grad # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) From e834be76c1974612d37d5b8e939102b085aadabd Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 14 May 2025 15:47:27 +0000 Subject: [PATCH 7/9] pass layer-index to mlp --- fast_llm/layers/transformer/mixture_of_experts.py | 2 +- fast_llm/layers/transformer/mlp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 49778c63f..a46af1387 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -44,7 +44,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, name) + super().__init__(config, tensor_space, name, layer_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index c4d8afdc7..b01eb2aa5 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -71,7 +71,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name) + super().__init__(config, tensor_space, name, layer_index) def forward( self, From f2f5265af9ec75fb76493f4364ec9420f95d6f97 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 16 Jun 2025 17:28:14 +0000 Subject: [PATCH 8/9] remove comments --- fast_llm/tensor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index c82a3bf18..849307563 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -234,9 +234,7 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - # TODO: re-enable when fixed? self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) - # self.requires_grad = requires_grad # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) From 6622040e10a967a6bb0703ee698c44e1d730c061 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 16 Jun 2025 19:02:28 +0000 Subject: [PATCH 9/9] remove per_layer_lr_scale in transformer config (already in llmblock) --- fast_llm/layers/transformer/config.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 3e710330d..54772e496 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -455,12 +455,6 @@ class TransformerConfig(LLMBlockConfig): doc="May be used to freeze some experts by setting their scale to zero.", hint=FieldHint.feature, ) - per_layer_lr_scale: list[float] | None = Field( - default=None, - desc="Custom learning rate scale for each layer.", - doc="May be used to freeze some layers by setting their scale to zero.", - hint=FieldHint.feature, - ) router_lr_scale: float | None = Field( default=None, desc="Custom learning rate for the MoE router weight.",