diff --git a/fast_llm/layers/common/auxiliary_loss.py b/fast_llm/layers/common/auxiliary_loss.py index c1006ad15..90e76a669 100644 --- a/fast_llm/layers/common/auxiliary_loss.py +++ b/fast_llm/layers/common/auxiliary_loss.py @@ -29,7 +29,7 @@ def z_loss( logits_scale_factor: float = 1.0, ) -> torch.Tensor: if losses is not None or (training and grad_scale is not None): - loss = calculate_z_loss(logits, logit_scale_factor=logits_scale_factor) + loss = calculate_z_loss(logits, logits_scale_factor=logits_scale_factor) if losses is not None and loss_name is not None: losses[loss_name].append(loss.detach()) if training and grad_scale is not None: