StarDoc model training #5
Conversation
Hi @jlamypoirier! @akshaykalkunte and I talked and we want to push this PR over the finish line. There's a lot going on here, and we should review the approach top down to decide how this needs to be refactored to go into
I think we can divide and conquer here.
As @akshaykalkunte pointed out recently, AlignVLM will be the best path forward for this first implementation. I read the paper and I don't see any obstacles. The method is refreshingly simple.
I think it's time to close this one since we have #227.
Splits the policy-gradient loss config and class hierarchy:

- LanguageModelPolicyGradientLossConfig (abstract base): shared fields (epsilon_low/high, metrics, normalize_by_documents, temperature).
- LanguageModelGRPOLossConfig: registers `type: grpo` (keeps GRPO-only use_triton).
- LanguageModelGSPOLossConfig: registers `type: gspo`.
- LanguageModelPolicyGradientLoss (abstract base): shared __init__/_forward_backward/_register_extra_metrics/get_loss_definitions/get_preprocessing_config plumbing; abstract `_call_kernel`.
- LanguageModelGRPOLoss / LanguageModelGSPOLoss: each implements `_call_kernel` against its kernel; GSPO overrides `get_preprocessing_config` to add `return_document_index`.

Drops the stringly-typed `policy_loss: str` switch and the in-method if/else dispatch, addressing review items #1 and #5 plus Note 2.

YAML migration: `type: grpo` + `policy_loss: gspo` → `type: gspo`. No checked-in YAML configs use the old form.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
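For readers skimming the thread, the shape of this split can be sketched in plain Python. This is only an illustration under assumptions, not the code in the PR: the class names are shortened, the `LOSS_CONFIGS` registry, `register`, `get_loss`, and `build_loss` helpers are hypothetical stand-ins for the real config registration, the default values are invented, and the kernels are replaced by placeholder return values.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

LOSS_CONFIGS = {}  # hypothetical registry: YAML `type` string -> config class


def register(type_name):
    """Hypothetical decorator mapping a `type:` value to a config class."""
    def decorator(cls):
        LOSS_CONFIGS[type_name] = cls
        return cls
    return decorator


@dataclass
class PolicyGradientLossConfig(ABC):
    # Shared fields named in the commit message; defaults are illustrative.
    epsilon_low: float = 0.2
    epsilon_high: float = 0.2
    temperature: float = 1.0
    normalize_by_documents: bool = False

    @abstractmethod
    def get_loss(self): ...


class PolicyGradientLoss(ABC):
    def __init__(self, config):
        self.config = config

    def forward_backward(self, logits):
        # Shared plumbing lives in the base; only the kernel call differs.
        return self._call_kernel(logits)

    def get_preprocessing_config(self):
        return {}

    @abstractmethod
    def _call_kernel(self, logits): ...


class GRPOLoss(PolicyGradientLoss):
    def _call_kernel(self, logits):
        return ("grpo", self.config.use_triton)  # placeholder for the GRPO kernel


class GSPOLoss(PolicyGradientLoss):
    def _call_kernel(self, logits):
        return ("gspo", None)  # placeholder for the GSPO kernel

    def get_preprocessing_config(self):
        # GSPO additionally asks preprocessing for the document index.
        return {**super().get_preprocessing_config(), "return_document_index": True}


@register("grpo")
@dataclass
class GRPOLossConfig(PolicyGradientLossConfig):
    use_triton: bool = False  # GRPO-only field

    def get_loss(self):
        return GRPOLoss(self)


@register("gspo")
@dataclass
class GSPOLossConfig(PolicyGradientLossConfig):
    def get_loss(self):
        return GSPOLoss(self)


def build_loss(raw):
    """Dispatch on `type` instead of a stringly-typed `policy_loss` switch."""
    raw = dict(raw)
    config_cls = LOSS_CONFIGS[raw.pop("type")]
    return config_cls(**raw).get_loss()
```

In this sketch, `build_loss({"type": "gspo"})` returns a `GSPOLoss`, while the old combined form `{"type": "grpo", "policy_loss": "gspo"}` fails at construction because `policy_loss` is no longer a field. That failure mode belongs to this toy version; the real config system may report the stale key differently.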
- Add `_sdp_dim`/`_sdp_active` to `LanguageModelLoss.__init__` so GSPO's SDP branch doesn't AttributeError on the first non-test call.
- Replace `document_index.max().item()` (and the SDP MAX all-reduce) with `len(kwargs[BlockKwargs.lengths])`: CPU-side, identical across SDP ranks, removes two GPU→CPU syncs per microbatch.
- Decorate `fused_gspo_loss_forward_backward` with `@torch.compile` for parity with GRPO. The `num_segments == 1` test case skips on CPU since torch._inductor's CPU codegen mishandles `index_add_` into a size-1 buffer (atomic_add scatter).
- Make `divisor` a required arg on `fused_gspo_loss_forward_backward`: the wrapper always overrides it with the global document count, and the previous local-rank default would silently mis-normalize under SDP.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
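The document-count and `divisor` changes can be illustrated with a small pure-Python sketch. It makes several assumptions that are not taken from the PR: plain lists stand in for tensors, documents are packed back to back with no empty documents, SDP is modelled as a simple split of the token sequence across two ranks, and the helper names (`build_document_index`, `count_from_index`, `normalize`) are hypothetical.

```python
def build_document_index(lengths):
    """Token-level document ids for documents packed back to back."""
    return [doc for doc, n in enumerate(lengths) for _ in range(n)]


def count_from_index(document_index):
    # Old-style count: on a GPU tensor the equivalent `.max().item()`
    # forces a device-to-host sync.
    return max(document_index) + 1


def normalize(per_document_losses, *, divisor):
    # `divisor` is keyword-only with no default, so a caller cannot fall
    # back to a local-rank count by accident.
    return sum(per_document_losses) / divisor


lengths = [3, 2, 4]  # three documents in one packed sequence
document_index = build_document_index(lengths)

# Assumed SDP layout: each rank sees a contiguous slice of the tokens.
rank_slices = [document_index[:5], document_index[5:]]
local_counts = [count_from_index(s) for s in rank_slices]

# New-style count: host-side and identical on every rank.
global_count = len(lengths)
```

Here the two ranks disagree on the count derived from their own slice (2 versus 3), which is why the index-based version needed a MAX all-reduce on top of the sync, whereas `len(lengths)` gives 3 everywhere. Calling `normalize(losses)` without `divisor` raises a `TypeError` instead of silently normalizing by a local count.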
WIP StarDoc model integration into FastLLM