Add GSPO loss#517
Open
jlamypoirier wants to merge 11 commits into
Open
Conversation
Group Sequence Policy Optimization: per-segment geometric-mean IS-ratio clipping. Mirrors GRPO's structure via shared abstract bases (LanguageModelPolicyGradientLossConfig / LanguageModelPolicyGradientLoss); the kernel matches GRPO except for a segment-aggregation block that produces per-segment R and A and broadcasts them back, so the softmax-chain backward is identical to GRPO. SDP-aware via optional all-reduce of segment sums; per-token weighting (mask / token_count_s) lets the SUM reduction at LossDef level give the canonical result without further correction. PyTorch kernel only; no Triton variant yet. Also lifts document_index_q/k from MixerKwargs to BlockKwargs so the LM head can read them without cross-namespace coupling, and renames grpo.py -> policy_gradient.py. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
# Conflicts: # fast_llm/layers/attention/config.py
- Add `_sdp_dim`/`_sdp_active` to `LanguageModelLoss.__init__` so GSPO's SDP branch doesn't AttributeError on the first non-test call. - Replace `document_index.max().item()` (and the SDP MAX all-reduce) with `len(kwargs[BlockKwargs.lengths])`: CPU-side, identical across SDP ranks, removes two GPU→CPU syncs per microbatch. - Decorate `fused_gspo_loss_forward_backward` with `@torch.compile` for parity with GRPO. The `num_segments == 1` test case skips on CPU since torch._inductor's CPU codegen mishandles `index_add_` into a size-1 buffer (atomic_add scatter). - Make `divisor` a required arg on `fused_gspo_loss_forward_backward`: the wrapper always overrides it with the global document count, and the previous local-rank default would silently mis-normalize under SDP. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Add `sp_group` arg to fused_gspo_loss_forward_backward and all-reduce the three segment buffers over it when sequence-parallel shards the sequence across the TP group; otherwise per-segment ratios use partial sums and produce silent corruption under SP. Wrapper passes `self._parallel_dim.group` when `_sequence_parallel` is active. - Wire `num_labels_in_seq` through the GSPO test and assert `new_logprobs_fused` against the reference. Required aligning the reference to use scaled logits for new_logprobs (reusing `target_log_probabilities`), matching the kernel's behavior of reporting the loss-path log-probs. - Drop the unreachable `max(num_segments, 1)` guard in the GSPO reference and the matching `divisor=max(num_segments, 1)` at the test call site. SDP all-reduce branch coverage (review item 3) deferred to a follow-up adding a `gspo_loss` flag to `tests/layers/test_lm_head.py` alongside the existing GRPO config, with an SDP distributed variant. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Rename `_sdp_dim`/`_sdp_active` to `_sequence_data_dim`/`_sequence_data_active` for codebase consistency and to avoid SP/SDP confusion now that `sp_group` exists adjacent in the kernel signature. - Rename per-token broadcasts to spell out their role: `flat_doc` → `flat_document_index`, `seg_advantage` → `advantage_per_token`, `inv_token_count` → `inverse_token_count`. - Drop section-header comments, the `# Broadcast back to per-token` line, and the redundant `.to(log_ratio.dtype)` on the already-typed product in the `log_ratio_sum` index_add. - Trim the kernel docstring middle paragraph to one line; collapse the `LanguageModelPolicyGradientLoss` docstring to its summary. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The original PR added the GSPO loss class and kernel-level tests but left no model-level integration coverage. This mirrors the existing GRPO wiring: - `tests/utils/model_configs.py`: add `llama_gspo` config alongside `llama_grpo` so the full model test matrix exercises GSPO end-to-end. - `tests/layers/test_lm_head.py`: add a `gspo_loss` flag with matching kwargs setup (document_index_q, lengths, per-document label_counts) and a hand-computed reference for both the loss and the per-doc new_logprobs metric. Configs added only with `num_splits=1`. Assert num_splits == 1 in `LanguageModelGSPOLoss.__init__` — `cross_entropy_splits` chunks the sequence across calls and per-segment aggregation can't recombine across chunks. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the kernel's locally-computed `token_count` (built via `index_add_` over the local fragment) with the preprocessor's `num_labels_in_seq` (per-document labeled-token count broadcast per token). Two payoffs: - `token_weight = mask / num_labels_in_seq` ensures all tokens of a document sum to 1 across SDP/SP shards, so the `/ num_documents_in_batch` divisor matches the segment count without the prior local-vs-global mismatch. - The geometric-mean denominator no longer requires per-segment extraction; the kernel pre-divides per-token contributions by `num_labels_in_seq`, then segment-aggregates. Mathematically equivalent to `log_ratio_sum / N_d` (since `N_d` is constant within a segment), and removes one index_add and one all-reduce. Documents must still be visible to a single kernel call (modulo SDP/SP). The only mechanism that violates this is `schedule.micro_batch_splits > 1`, which produces partial per-fragment `exp(mean)` values that can't be linearly recombined into the whole-document ratio. `llama_gspo` skips the `ms*` distributed test variants; the constraint is documented in the kernel docstring. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Make the 0-based convention explicit at the kernel boundary so a future caller passing 1-based `BlockKwargs.document_index_q` directly fails loudly instead of feeding an off-by-one index into `index_add_`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Dynamo lifts the Python-int `divisor` and `num_segments` args to symbolic ints with no concrete hint at trace time, then trips on `grad_output / divisor` with `ZeroDivisionError` (hint=0). The Triton kernel in the follow-up PR is the actual GPU perf path; the eager PyTorch fallback runs without torch.compile and no longer needs the CPU-only `skipif` for the `num_segments == 1` case (that workaround was specific to inductor's CPU codegen). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The kernel's per-segment SDP/SP all-reduce was broken because
`document_index_q` produced by `LengthModelInputPreprocessor` uses
**local** (per-rank-cropped) document IDs — same doc gets different
IDs on different ranks, so the cross-rank `index_add_` collation
silently mixed unrelated documents (and hung on size-mismatched
buffers when local lengths differed across ranks).
Fix:
- Add `BlockKwargs.global_document_index_q` and `num_documents_in_sequence`,
computed in `TokenBatch._get_model_input` from the unsliced batch's
cumulative lengths. The IDs are 1-based (matching `document_index_q`'s
convention) and consistent across all SDP/SP ranks within a DP rank.
- GSPO consumes the new fields. `num_segments` now uses
`num_documents_in_sequence` (DP-rank-local whole-doc count), so per-segment
buffers are sized identically across SDP/SP for a well-defined all-reduce.
- `_get_model_input.num_documents` setter now triggers on `is_first_for_rank`
rather than `begin == 0`, so SDP ranks > 0 contribute to the
`num_documents_in_batch` DP allreduce instead of zero-ing it out.
- `llama_gspo` `skip_tests` widened to `("ms", "ce", "bf16")`:
- `ms*`/`ce*`: split sequences across kernel calls — `exp(mean)` can't
be reconstructed from per-fragment values (architectural limitation).
- `bf16`: numerical noise on tiny log-ratio values amplified by
`exp(mean)` exceeds the default vs-fp32 tolerance.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
LanguageModelPolicyGradientLossConfig/LanguageModelPolicyGradientLoss) for the common scaffolding (epsilon_low/high, new_logprobs metric, preprocessing flags) and refactors GRPO to inherit from them.document_index_q/document_index_kfromMixerKwargsup toBlockKwargsso the LM head can read them without depending on attention-namespaced keys.fast_llm/layers/language_model/loss/grpo.pytopolicy_gradient.py(now contains both losses + shared base).Design notes
The GSPO kernel structurally parallels
fused_grpo_loss_forward_backward:log_ratio,advantages, and counts into per-segment buffers; optional SDP all-reduce of the three buffers; computeR_s = exp(lrn_sum / token_count)andA_s = adv_sum / token_count(detached).R_s,A_sback to per-token; per-token loss weight ismask / token_count_sso each segment contributes once to the sum.probability_ratio = R_{s(t)}and per-token advantages replaced byA_{s(t)}.The per-token decomposition gets SDP correctness "for free": each rank only sums contributions from its own tokens, so SUM-reducing at the
LossDeflevel reproduces the canonical single-rank result with no/sdp_sizecorrection.No Triton variant yet — comes in a follow-up.
Test plan
pytest tests/layers/test_lm_losses.py::test_gspo_loss— 20 cases (10 param sets × 2 batch shapes)pytest tests/layers/test_lm_losses.py::test_grpo_loss— 20 cases pass after refactorpytest tests/layers/test_lm_losses.py::test_grpo_metrics— 40 cases pass after rename