Per-loss logits_scale_factor #516

Merged
jlamypoirier merged 1 commit into main from jlp_loss-logits-scale-factor on May 15, 2026

Conversation

@jlamypoirier
Collaborator

Summary

Add a logits_scale_factor field to LanguageModelLossConfig, applied on top of the head's logits_scale_factor for that loss only. Every loss subclass picks it up automatically via self._logits_scale_factor, so no per-subclass changes are needed.
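A minimal sketch of the stacking behaviour described above, assuming the per-loss factor is multiplied into the head's factor once at construction time. The class names `HeadConfigSketch`, `LossConfigSketch`, and `LossSketch` are hypothetical stand-ins, not the repository's actual classes; only the field name `logits_scale_factor` and the attribute `self._logits_scale_factor` come from this PR.

```python
from dataclasses import dataclass


@dataclass
class HeadConfigSketch:
    # Hypothetical stand-in for the head config; scale applied to all losses.
    logits_scale_factor: float = 1.0


@dataclass
class LossConfigSketch:
    # Hypothetical stand-in for LanguageModelLossConfig; extra per-loss scale.
    logits_scale_factor: float = 1.0


class LossSketch:
    """Hypothetical base loss: subclasses only read self._logits_scale_factor."""

    def __init__(self, loss_config: LossConfigSketch, head_config: HeadConfigSketch):
        # Assumption: the per-loss factor multiplies the head's factor.
        self._logits_scale_factor = (
            head_config.logits_scale_factor * loss_config.logits_scale_factor
        )

    def scale(self, logits: list[float]) -> list[float]:
        # Apply the combined scale before any softmax or log-softmax.
        return [self._logits_scale_factor * x for x in logits]


default_loss = LossSketch(LossConfigSketch(), HeadConfigSketch(logits_scale_factor=0.5))
scaled_loss = LossSketch(LossConfigSketch(logits_scale_factor=2.0), HeadConfigSketch(logits_scale_factor=0.5))
```

With the default per-loss value of 1.0, the combined factor equals the head's factor, so existing configurations would behave as before under this assumption.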

Primary use case: in RL losses, set to 1 / actor_temperature so new log-probabilities are computed at the same scale as the actor's stored old log-probabilities (importance ratio at step 0 is no longer offset by the actor's sampling temperature).
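The effect on the importance ratio can be illustrated with a small standalone calculation. This is a sketch under stated assumptions, not the repository's loss code: it assumes the actor's stored old log-probabilities come from a softmax over `logits / actor_temperature`, and that the policy weights are unchanged at step 0. The function names and the example values are hypothetical.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    # Numerically stable log-softmax over a single vocabulary row.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def importance_ratio(
    logits: list[float],
    token: int,
    actor_temperature: float,
    logits_scale_factor: float,
) -> float:
    # Old log-prob: as sampled by the actor, at its sampling temperature.
    old = log_softmax([x / actor_temperature for x in logits])[token]
    # New log-prob: computed by the loss with its own logits scale.
    new = log_softmax([logits_scale_factor * x for x in logits])[token]
    return math.exp(new - old)


logits = [2.0, 0.5, -1.0]  # illustrative values, identical weights at step 0
actor_temperature = 0.7

# Without the per-loss scale, the ratio is offset by the temperature mismatch.
unscaled = importance_ratio(logits, 0, actor_temperature, 1.0)
# With logits_scale_factor = 1 / actor_temperature, the two scales match.
matched = importance_ratio(logits, 0, actor_temperature, 1.0 / actor_temperature)
```

Here `matched` comes out at 1 up to floating-point error, while `unscaled` does not, which is the step-0 offset the description refers to.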

Reimplemented from PR #502's temperature field on the GRPO config: the field is moved to the base config (so all losses can opt in), renamed to match the existing head field, and given multiplier semantics (extra scale, default 1.0, stacked on top of the head's scale).

Test plan

  • Existing loss tests still pass at default logits_scale_factor=1.0.
  • Manual: setting logits_scale_factor=2.0 on a GRPO loss config produces softmax outputs equivalent to doubling the head's logits_scale_factor.
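The manual check above could be approximated by a standalone script along these lines. It is a sketch that assumes the two factors combine by multiplication before the softmax; `scaled_softmax` and the sample logits are hypothetical and do not exercise the real GRPO loss.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    # Numerically stable softmax over a single row of logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_softmax(
    logits: list[float], head_scale: float, loss_scale: float = 1.0
) -> list[float]:
    # Assumption: the per-loss scale multiplies the head's scale.
    scale = head_scale * loss_scale
    return softmax([scale * x for x in logits])


logits = [1.5, -0.3, 0.8, 0.0]  # illustrative values
head_scale = 0.5

# Per-loss factor of 2.0 stacked on the head's factor...
per_loss = scaled_softmax(logits, head_scale, loss_scale=2.0)
# ...should match doubling the head's factor with the default per-loss value.
doubled_head = scaled_softmax(logits, 2.0 * head_scale)

outputs_match = all(
    math.isclose(a, b, rel_tol=1e-12) for a, b in zip(per_loss, doubled_head)
)
```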

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jlamypoirier jlamypoirier merged commit 8c6b67c into main May 15, 2026
2 of 3 checks passed
@jlamypoirier jlamypoirier deleted the jlp_loss-logits-scale-factor branch May 15, 2026 22:28
@jlamypoirier jlamypoirier mentioned this pull request May 15, 2026
4 tasks