diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 70cf8806a..8e9594534 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -41,6 +41,15 @@ class LanguageModelLossConfig(Config): desc="Weight for this loss in the total loss computation.", valid=check_field(Assert.geq, 0.0), ) + logits_scale_factor: float = Field( + default=1.0, + hint=FieldHint.feature, + desc=( + "Extra logits scale factor applied for this loss only, stacked on top of the model's" + " `logits_scale_factor`." + ), + valid=check_field(Assert.gt, 0.0), + ) def get_layer( self, diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 3cab2bca8..034953b50 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -33,7 +33,7 @@ def __init__( self._prediction_heads = prediction_heads self._name = name self._num_splits = num_splits - self._logits_scale_factor = logits_scale_factor + self._logits_scale_factor = logits_scale_factor * self._config.logits_scale_factor self._weight = weight * self._config.weight self._do_register_loss = register_loss self._vocab_parallel = distributed_config.tensor_parallel > 1 and vocab_parallel