Skip to content

StarDoc model training#5

Closed
akshaykalkunte wants to merge 8 commits into
mainfrom
akshay/stardoc
Closed

StarDoc model training#5
akshaykalkunte wants to merge 8 commits into
mainfrom
akshay/stardoc

Conversation

@akshaykalkunte
Copy link
Copy Markdown
Contributor

WIP StarDoc model integration into FastLLM

@tscholak
Copy link
Copy Markdown
Collaborator

tscholak commented Nov 11, 2024

Hi @jlamypoirier! @akshaykalkunte and I talked and we want to push this PR over the finish line. There's a lot going on here, and we should review the approach top down to decide how this needs to be refactored to go into main. At the top of my head are the following separate concerns:

  1. Model architecture: Are VLMs GPTs from the point of view of Fast-LLM? I think they aren't because too much is different. We should add a new model architecture (e.g. "vlm") to Fast-LLM.
  2. Data preprocessing: Related to Add prepare command #38, we should factor out data preprocessing and introduce an offline preprocessing step, fast-llm prepare_data vlm --config stardoc.yaml, that makes VLMMemmapDatasets and stores them on disk.
  3. Vision encoder implementation: Right now it's a monolithic wrapper layer that uses a HF auto model. We should discuss if and when we reimplement this in Fast-LLM. This can be a separate effort and (as a side effect) result in yet another model class, vision_encoder, that we can also train from scratch if we wanted to.
  4. Cross-attention instead of adapter layer: StarDoc is moving towards a special form of cross-attention between the vision encoder and the LM decoder. This likely has implications for parallelization.
  5. Llama 3 support: StarDoc will use pre-trained Llama 3.2 (text-only?) models, we need to be able to load them. See also [feat] Llama 3.x rope scaling support #39.
  6. YAML configs: This PR currently doesn't support Fast-LLM's new YAML-based configs.

I think we can divide and conquer here.

@tscholak tscholak mentioned this pull request Nov 11, 2024
24 tasks
@tscholak
Copy link
Copy Markdown
Collaborator

tscholak commented Feb 7, 2025

4. Cross-attention instead of adapter layer: StarDoc is moving towards a special form of cross-attention between the vision encoder and the LM decoder. This likely has implications for parallelization.

As @akshaykalkunte pointed out recently, AlignVLM will be the best path forward for this first implementation. I read the paper and I don't see any obstacles. The method is refreshingly simple.

@tscholak
Copy link
Copy Markdown
Collaborator

tscholak commented May 9, 2025

I think it's time to close this one since we have #227

@tscholak tscholak closed this May 9, 2025
@jlamypoirier jlamypoirier deleted the akshay/stardoc branch September 19, 2025 01:08
jlamypoirier added a commit that referenced this pull request May 7, 2026
Splits the policy-gradient loss config and class hierarchy:

- LanguageModelPolicyGradientLossConfig (abstract base): shared fields
  (epsilon_low/high, metrics, normalize_by_documents, temperature).
- LanguageModelGRPOLossConfig: registers `type: grpo` (keeps GRPO-only
  use_triton).
- LanguageModelGSPOLossConfig: registers `type: gspo`.
- LanguageModelPolicyGradientLoss (abstract base): shared
  __init__/_forward_backward/_register_extra_metrics/get_loss_definitions/
  get_preprocessing_config plumbing; abstract `_call_kernel`.
- LanguageModelGRPOLoss / LanguageModelGSPOLoss: each implements
  `_call_kernel` against its kernel; GSPO overrides
  `get_preprocessing_config` to add `return_document_index`.

Drops the stringly-typed `policy_loss: str` switch and the in-method
if/else dispatch, addressing review items #1 and #5 plus Note 2.

YAML migration: `type: grpo` + `policy_loss: gspo` → `type: gspo`.
No checked-in YAML configs use the old form.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
jlamypoirier added a commit that referenced this pull request May 20, 2026
- Add `_sdp_dim`/`_sdp_active` to `LanguageModelLoss.__init__` so GSPO's
  SDP branch doesn't AttributeError on the first non-test call.
- Replace `document_index.max().item()` (and the SDP MAX all-reduce) with
  `len(kwargs[BlockKwargs.lengths])`: CPU-side, identical across SDP ranks,
  removes two GPU→CPU syncs per microbatch.
- Decorate `fused_gspo_loss_forward_backward` with `@torch.compile` for
  parity with GRPO. The `num_segments == 1` test case skips on CPU since
  torch._inductor's CPU codegen mishandles `index_add_` into a size-1
  buffer (atomic_add scatter).
- Make `divisor` a required arg on `fused_gspo_loss_forward_backward`:
  the wrapper always overrides it with the global document count, and
  the previous local-rank default would silently mis-normalize under SDP.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants