Skip to content

Add GSPO loss#517

Open
jlamypoirier wants to merge 11 commits into
mainfrom
jlp_gspo
Open

Add GSPO loss#517
jlamypoirier wants to merge 11 commits into
mainfrom
jlp_gspo

Conversation

@jlamypoirier
Copy link
Copy Markdown
Collaborator

Summary

  • Adds Group Sequence Policy Optimization (per-segment geometric-mean IS-ratio clipping) as a sibling to GRPO. Sequence-level surrogate, computed as a per-token sum so the softmax-chain backward and SDP partitioning fall out identically to GRPO.
  • Extracts shared abstract bases (LanguageModelPolicyGradientLossConfig / LanguageModelPolicyGradientLoss) for the common scaffolding (epsilon_low/high, new_logprobs metric, preprocessing flags) and refactors GRPO to inherit from them.
  • Lifts document_index_q/document_index_k from MixerKwargs up to BlockKwargs so the LM head can read them without depending on attention-namespaced keys.
  • Renames fast_llm/layers/language_model/loss/grpo.py to policy_gradient.py (now contains both losses + shared base).

Design notes

The GSPO kernel structurally parallels fused_grpo_loss_forward_backward:

  1. Same softmax → predicted_logits → new_log_probs setup.
  2. New mid-kernel block: scatter-add per-token log_ratio, advantages, and counts into per-segment buffers; optional SDP all-reduce of the three buffers; compute R_s = exp(lrn_sum / token_count) and A_s = adv_sum / token_count (detached).
  3. Broadcast R_s, A_s back to per-token; per-token loss weight is mask / token_count_s so each segment contributes once to the sum.
  4. Downstream loss and the softmax-chain backward are line-for-line identical to GRPO with probability_ratio = R_{s(t)} and per-token advantages replaced by A_{s(t)}.

The per-token decomposition gets SDP correctness "for free": each rank only sums contributions from its own tokens, so SUM-reducing at the LossDef level reproduces the canonical single-rank result with no /sdp_size correction.

No Triton variant yet — comes in a follow-up.

Test plan

  • pytest tests/layers/test_lm_losses.py::test_gspo_loss — 20 cases (10 param sets × 2 batch shapes)
  • pytest tests/layers/test_lm_losses.py::test_grpo_loss — 20 cases pass after refactor
  • pytest tests/layers/test_lm_losses.py::test_grpo_metrics — 40 cases pass after rename
  • End-to-end RL training run (not exercised by this PR)

Group Sequence Policy Optimization: per-segment geometric-mean IS-ratio
clipping. Mirrors GRPO's structure via shared abstract bases
(LanguageModelPolicyGradientLossConfig / LanguageModelPolicyGradientLoss);
the kernel matches GRPO except for a segment-aggregation block that
produces per-segment R and A and broadcasts them back, so the softmax-chain
backward is identical to GRPO. SDP-aware via optional all-reduce of
segment sums; per-token weighting (mask / token_count_s) lets the SUM
reduction at LossDef level give the canonical result without further
correction. PyTorch kernel only; no Triton variant yet.

Also lifts document_index_q/k from MixerKwargs to BlockKwargs so the LM
head can read them without cross-namespace coupling, and renames
grpo.py -> policy_gradient.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This was referenced May 19, 2026
jlamypoirier and others added 10 commits May 19, 2026 17:46
# Conflicts:
#	fast_llm/layers/attention/config.py
- Add `_sdp_dim`/`_sdp_active` to `LanguageModelLoss.__init__` so GSPO's
  SDP branch doesn't AttributeError on the first non-test call.
- Replace `document_index.max().item()` (and the SDP MAX all-reduce) with
  `len(kwargs[BlockKwargs.lengths])`: CPU-side, identical across SDP ranks,
  removes two GPU→CPU syncs per microbatch.
- Decorate `fused_gspo_loss_forward_backward` with `@torch.compile` for
  parity with GRPO. The `num_segments == 1` test case skips on CPU since
  torch._inductor's CPU codegen mishandles `index_add_` into a size-1
  buffer (atomic_add scatter).
- Make `divisor` a required arg on `fused_gspo_loss_forward_backward`:
  the wrapper always overrides it with the global document count, and
  the previous local-rank default would silently mis-normalize under SDP.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Add `sp_group` arg to fused_gspo_loss_forward_backward and all-reduce the
  three segment buffers over it when sequence-parallel shards the sequence
  across the TP group; otherwise per-segment ratios use partial sums and
  produce silent corruption under SP. Wrapper passes `self._parallel_dim.group`
  when `_sequence_parallel` is active.
- Wire `num_labels_in_seq` through the GSPO test and assert
  `new_logprobs_fused` against the reference. Required aligning the reference
  to use scaled logits for new_logprobs (reusing `target_log_probabilities`),
  matching the kernel's behavior of reporting the loss-path log-probs.
- Drop the unreachable `max(num_segments, 1)` guard in the GSPO reference and
  the matching `divisor=max(num_segments, 1)` at the test call site.

SDP all-reduce branch coverage (review item 3) deferred to a follow-up adding
a `gspo_loss` flag to `tests/layers/test_lm_head.py` alongside the existing
GRPO config, with an SDP distributed variant.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Rename `_sdp_dim`/`_sdp_active` to `_sequence_data_dim`/`_sequence_data_active`
  for codebase consistency and to avoid SP/SDP confusion now that `sp_group`
  exists adjacent in the kernel signature.
- Rename per-token broadcasts to spell out their role:
  `flat_doc` → `flat_document_index`,
  `seg_advantage` → `advantage_per_token`,
  `inv_token_count` → `inverse_token_count`.
- Drop section-header comments, the `# Broadcast back to per-token` line,
  and the redundant `.to(log_ratio.dtype)` on the already-typed product
  in the `log_ratio_sum` index_add.
- Trim the kernel docstring middle paragraph to one line; collapse the
  `LanguageModelPolicyGradientLoss` docstring to its summary.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The original PR added the GSPO loss class and kernel-level tests but
left no model-level integration coverage. This mirrors the existing
GRPO wiring:

- `tests/utils/model_configs.py`: add `llama_gspo` config alongside
  `llama_grpo` so the full model test matrix exercises GSPO end-to-end.
- `tests/layers/test_lm_head.py`: add a `gspo_loss` flag with matching
  kwargs setup (document_index_q, lengths, per-document label_counts)
  and a hand-computed reference for both the loss and the per-doc
  new_logprobs metric. Configs added only with `num_splits=1`.

Assert num_splits == 1 in `LanguageModelGSPOLoss.__init__` —
`cross_entropy_splits` chunks the sequence across calls and per-segment
aggregation can't recombine across chunks.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the kernel's locally-computed `token_count` (built via `index_add_`
over the local fragment) with the preprocessor's `num_labels_in_seq`
(per-document labeled-token count broadcast per token). Two payoffs:

- `token_weight = mask / num_labels_in_seq` ensures all tokens of a
  document sum to 1 across SDP/SP shards, so the
  `/ num_documents_in_batch` divisor matches the segment count without
  the prior local-vs-global mismatch.
- The geometric-mean denominator no longer requires per-segment
  extraction; the kernel pre-divides per-token contributions by
  `num_labels_in_seq`, then segment-aggregates. Mathematically
  equivalent to `log_ratio_sum / N_d` (since `N_d` is constant within a
  segment), and removes one index_add and one all-reduce.

Documents must still be visible to a single kernel call (modulo
SDP/SP). The only mechanism that violates this is
`schedule.micro_batch_splits > 1`, which produces partial per-fragment
`exp(mean)` values that can't be linearly recombined into the
whole-document ratio. `llama_gspo` skips the `ms*` distributed test
variants; the constraint is documented in the kernel docstring.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Make the 0-based convention explicit at the kernel boundary so a future
caller passing 1-based `BlockKwargs.document_index_q` directly fails
loudly instead of feeding an off-by-one index into `index_add_`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Dynamo lifts the Python-int `divisor` and `num_segments` args to
symbolic ints with no concrete hint at trace time, then trips on
`grad_output / divisor` with `ZeroDivisionError` (hint=0). The Triton
kernel in the follow-up PR is the actual GPU perf path; the eager
PyTorch fallback runs without torch.compile and no longer needs the
CPU-only `skipif` for the `num_segments == 1` case (that workaround
was specific to inductor's CPU codegen).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The kernel's per-segment SDP/SP all-reduce was broken because
`document_index_q` produced by `LengthModelInputPreprocessor` uses
**local** (per-rank-cropped) document IDs — same doc gets different
IDs on different ranks, so the cross-rank `index_add_` collation
silently mixed unrelated documents (and hung on size-mismatched
buffers when local lengths differed across ranks).

Fix:
- Add `BlockKwargs.global_document_index_q` and `num_documents_in_sequence`,
  computed in `TokenBatch._get_model_input` from the unsliced batch's
  cumulative lengths. The IDs are 1-based (matching `document_index_q`'s
  convention) and consistent across all SDP/SP ranks within a DP rank.
- GSPO consumes the new fields. `num_segments` now uses
  `num_documents_in_sequence` (DP-rank-local whole-doc count), so per-segment
  buffers are sized identically across SDP/SP for a well-defined all-reduce.
- `_get_model_input.num_documents` setter now triggers on `is_first_for_rank`
  rather than `begin == 0`, so SDP ranks > 0 contribute to the
  `num_documents_in_batch` DP allreduce instead of zero-ing it out.
- `llama_gspo` `skip_tests` widened to `("ms", "ce", "bf16")`:
  - `ms*`/`ce*`: split sequences across kernel calls — `exp(mean)` can't
    be reconstructed from per-fragment values (architectural limitation).
  - `bf16`: numerical noise on tiny log-ratio values amplified by
    `exp(mean)` exceeds the default vs-fp32 tolerance.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant