Fix backward compatibility in normalization #3

Merged
jlamypoirier merged 1 commit into main from flat_backward_compatible on Oct 16, 2024

Conversation

jlamypoirier (Collaborator) commented Oct 16, 2024

Hopefully this works.

jlamypoirier merged commit 87f23a0 into main on Oct 16, 2024
jlamypoirier deleted the flat_backward_compatible branch on October 16, 2024 14:37
tscholak added this to the 0.2.0 milestone on Oct 25, 2024
jlamypoirier added a commit that referenced this pull request May 20, 2026
- Add `_sdp_dim`/`_sdp_active` to `LanguageModelLoss.__init__` so GSPO's
  SDP branch doesn't AttributeError on the first non-test call.
- Replace `document_index.max().item()` (and the SDP MAX all-reduce) with
  `len(kwargs[BlockKwargs.lengths])`: CPU-side, identical across SDP ranks,
  removes two GPU→CPU syncs per microbatch.
- Decorate `fused_gspo_loss_forward_backward` with `@torch.compile` for
  parity with GRPO. The `num_segments == 1` test case skips on CPU since
  torch._inductor's CPU codegen mishandles `index_add_` into a size-1
  buffer (atomic_add scatter).
- Make `divisor` a required arg on `fused_gspo_loss_forward_backward`:
  the wrapper always overrides it with the global document count, and
  the previous local-rank default would silently mis-normalize under SDP.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
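
Illustrative note (not part of the commit): the second and fourth bullets of the commit message above can be sketched in plain Python, without torch. This is a minimal sketch under stated assumptions, not the repository's code. The helper names `document_index_from_lengths` and `normalized_loss` are hypothetical, the document index is assumed to be 0-based, and lists stand in for tensors and for SDP ranks.

```python
from itertools import chain


def document_index_from_lengths(lengths):
    """Expand per-document token counts into a 0-based per-token document index.

    Hypothetical helper: it only mimics what a `document_index` tensor might hold.
    """
    return list(chain.from_iterable([i] * n for i, n in enumerate(lengths)))


def normalized_loss(per_document_losses, *, divisor):
    """Sum the losses and divide by a caller-supplied divisor.

    `divisor` is keyword-only with no default, so a caller cannot fall back
    to a rank-local count by accident.
    """
    return sum(per_document_losses) / divisor


lengths = [3, 2, 4]  # token counts of three documents in one microbatch
document_index = document_index_from_lengths(lengths)

# The document count is available from the host-side lengths list;
# no reduction over the per-token index is needed to obtain it.
num_documents = len(lengths)
assert num_documents == max(document_index) + 1  # assumes 0-based indices

# Two simulated ranks holding different numbers of documents.
rank_losses = [[1.0, 2.0], [3.0]]
global_count = sum(len(losses) for losses in rank_losses)

# Dividing every rank's partial sum by the global count gives the global mean.
total = sum(normalized_loss(losses, divisor=global_count) for losses in rank_losses)

# Dividing by each rank's local count instead gives a different value.
local_total = sum(normalized_loss(losses, divisor=len(losses)) for losses in rank_losses)
```

In this toy setup `total` is the mean over all three documents, while `local_total` over-weights the rank that holds fewer documents, which is the kind of silent mis-normalization the last bullet describes.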

2 participants