diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 1f4ac1a17..cbe17101e 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -197,6 +197,12 @@ def _validate(self) -> None: Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) for coeff in self.prediction_loss_coefficient: Assert.geq(coeff, 0) + if self.transformer.per_layer_lr_scale is not None: + # -1 because the first prediction head's transformer layer is accounted for in num_layers + # +1 because the layer index starts at 1 + Assert.eq( + len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + self.prediction_heads - 1 + 1 + ) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: self.transformer.setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 55b38ea76..3351c9906 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -82,7 +82,7 @@ def __init__( super().__init__() self._config = config self._tensor_space = tensor_space - Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) + # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 49778c63f..a46af1387 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -44,7 +44,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, name) + super().__init__(config, tensor_space, name, layer_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index c4d8afdc7..b01eb2aa5 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -71,7 +71,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name) + super().__init__(config, tensor_space, name, layer_index) def forward( self, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 1c5eb1406..582575c01 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -72,7 +72,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers, 1), + layer_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1,