
Direct Preference Optimization (DPO) support #209

@tscholak


🎯 Goal (What & Why)

Add support for Direct Preference Optimization (DPO). This enables fine-tuning with preference-labeled data using a stable, RL-free objective. DPO guides the model to prefer chosen completions over rejected ones for the same prompt. This will broaden Fast-LLM's training capabilities while preserving a unified dataset structure across pretraining, SFT, and DPO.

🚀 Execution Plan

Step 1: Implement simplified DPO without a reference model

  • Extend the GPTSample dataclass with two optional string fields:

    chosen: str | None = None
    rejected: str | None = None

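    The resulting field usage per training mode:

    | Field              | Pretraining   | SFT           | DPO (new)         |
    |--------------------|---------------|---------------|-------------------|
    | text               | ✅ main input | ✅ main input | ✅ used as prompt |
    | loss_masking_spans |               | ✅ used       | ❌ ignored        |
    | chosen / rejected  |               |               | ✅ required       |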
  • Update GPTMemmapDataset.__getitem__ to return GPTSample instances with chosen and rejected fields populated when present in the data.

  • Modify fast-llm prepare to handle DPO-style preference data:

    • Add CLI flags to fast-llm prepare to either force DPO mode or auto-detect it from the fields present.
    • Expect fields text, chosen, rejected (names configurable) in the input dataset.
    • Store all three fields in the bin files for training-time access.
  • Add logic to ignore loss_masking_spans in DPO mode.

  • Add config options for DPO, including β (the temperature).

  • Implement DPO loss (simplified form):

    L = -log σ(β * log(p_theta(chosen | x) / p_theta(rejected | x)))

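    where p_theta is the model being trained, x is the prompt (text), and σ is the logistic sigmoid. Each p_theta(y | x) is the probability of the whole completion y, so its log is the sum of the per-token log-probabilities over the completion tokens only.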
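A minimal sketch of the simplified loss, written in plain Python so it runs without a framework. It assumes per-step logits are given as lists where `logits_per_step[i]` predicts `targets[i]`, and that prompt tokens are masked out so only the completion contributes to log p_theta(y | x). All function names are hypothetical, not Fast-LLM APIs; a real implementation would operate on batched tensors.

```python
import math
from typing import Sequence


def log_softmax(logits: Sequence[float]) -> list[float]:
    # Subtract the max before exponentiating to avoid overflow.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def completion_logprob(
    logits_per_step: Sequence[Sequence[float]],
    targets: Sequence[int],
    completion_mask: Sequence[bool],
) -> float:
    """log p_theta(y | x): summed next-token log-probs over completion tokens.

    logits_per_step[i] are the logits that predict targets[i]. Positions whose
    completion_mask entry is False (prompt tokens) are skipped.
    """
    total = 0.0
    for logits, target, keep in zip(logits_per_step, targets, completion_mask):
        if keep:
            total += log_softmax(logits)[target]
    return total


def neg_log_sigmoid(z: float) -> float:
    # -log(sigmoid(z)) = log(1 + exp(-z)), branched so exp never overflows.
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


def simplified_dpo_loss(logp_chosen: float, logp_rejected: float, beta: float) -> float:
    # L = -log sigma(beta * (log p(chosen | x) - log p(rejected | x)))
    return neg_log_sigmoid(beta * (logp_chosen - logp_rejected))
```

With equal chosen and rejected log-probabilities the margin is zero and the loss is log 2 ≈ 0.693; it decreases toward zero as the chosen completion becomes more likely than the rejected one.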

Design options

Packing

  1. Pack the chosen and rejected completions of a single GPTSample into one sequence (similar to OpenRLHF).
  2. Pack chosen/rejected pairs from multiple samples into one sequence.
  3. Use distinct sequences for chosen and rejected completions.

The preferred option is 2, but if it proves too complicated we can start with 1. Either method requires splitting the packed sequence at bos tokens to obtain the chosen and rejected log-probs.

This requires modifying GPTSampledIndexedDataset; alternatively, it may be easier to create a new DPO-specific class with logic similar to that used for padding.
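A sketch of packing options 1 and 2 and the bos-based split, under stated assumptions: `PreferencePair`, `pack_pairs`, and `split_logprobs_by_bos` are hypothetical names (not Fast-LLM's GPTSample or dataset classes), the bos id is assumed never to occur inside a prompt or completion, and attention is assumed to be masked per segment so the rejected completion cannot attend to the chosen one. Each segment is laid out as bos, prompt, completion, with chosen always preceding rejected.

```python
from dataclasses import dataclass


@dataclass
class PreferencePair:
    # Hypothetical tokenized DPO sample; not Fast-LLM's GPTSample.
    prompt: list[int]
    chosen: list[int]
    rejected: list[int]


def pack_pairs(pairs: list[PreferencePair], bos: int) -> tuple[list[int], list[bool]]:
    """Pack as [bos, prompt, chosen, bos, prompt, rejected, ...].

    One pair corresponds to option 1, several pairs to option 2. Returns the
    token ids and a mask that is True only on completion tokens.
    """
    tokens: list[int] = []
    completion_mask: list[bool] = []
    for pair in pairs:
        for completion in (pair.chosen, pair.rejected):
            tokens += [bos, *pair.prompt, *completion]
            completion_mask += [False] * (1 + len(pair.prompt)) + [True] * len(completion)
    return tokens, completion_mask


def split_logprobs_by_bos(
    tokens: list[int],
    token_logprobs: list[float],
    completion_mask: list[bool],
    bos: int,
) -> list[tuple[float, float]]:
    """Sum completion log-probs per bos-delimited segment, then pair them up.

    token_logprobs[i] is assumed to be log p(tokens[i] | earlier tokens in the
    same segment). Returns [(logp_chosen, logp_rejected), ...].
    """
    segment_sums: list[float] = []
    for token, logprob, keep in zip(tokens, token_logprobs, completion_mask):
        if token == bos:
            segment_sums.append(0.0)  # every bos starts a new segment
        elif keep:
            segment_sums[-1] += logprob
    if len(segment_sums) % 2:
        raise ValueError("expected an even number of segments (chosen, rejected)")
    return list(zip(segment_sums[0::2], segment_sums[1::2]))
```

Each returned pair can be fed directly into the per-pair loss. Note that this layout repeats the prompt in both segments; that keeps the split trivial at the cost of some extra tokens.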

Step 2 (optional): Add support for reference model normalization

  • Load a frozen reference model (p_ref) alongside the model being trained.

  • Update loss to full DPO form:

    L = -log σ(β * [log(p_theta(chosen | x) / p_theta(rejected | x)) - log(p_ref(chosen | x) / p_ref(rejected | x))])

  • Allow users to enable/disable this via config.
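A minimal sketch of the full loss with the reference term behind a config switch, assuming per-pair summed completion log-probs have already been computed for both models. `DPOConfig`, its field names, and `dpo_loss` are hypothetical illustrations, not Fast-LLM's actual config schema.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class DPOConfig:
    # Hypothetical config fields for illustration; not Fast-LLM's actual schema.
    beta: float = 0.1
    use_reference_model: bool = False


def neg_log_sigmoid(z: float) -> float:
    # -log(sigmoid(z)) = log(1 + exp(-z)), branched so exp never overflows.
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


def dpo_loss(
    config: DPOConfig,
    policy_chosen: float,
    policy_rejected: float,
    ref_chosen: Optional[float] = None,
    ref_rejected: Optional[float] = None,
) -> float:
    """DPO loss for one pair; every argument is a summed log p(y | x)."""
    # Step 1 margin: log p_theta(chosen | x) - log p_theta(rejected | x).
    margin = policy_chosen - policy_rejected
    if config.use_reference_model:
        if ref_chosen is None or ref_rejected is None:
            raise ValueError("reference log-probs are required when use_reference_model=True")
        # Step 2: subtract the frozen reference model's log-ratio.
        margin -= ref_chosen - ref_rejected
    return neg_log_sigmoid(config.beta * margin)
```

With `use_reference_model=False` this reduces to the Step 1 loss, so a single code path can serve both steps. Since the reference model is frozen, its log-probs need no gradients.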

Offline computation of reference scores is possible but not recommended:

  • Adds preprocessing cost.
  • Ties dataset to a specific tokenizer/model.
  • Makes experimentation brittle.

📌 Acceptance Criteria

  • DPO loss is implemented and tested (initially without reference model).
  • fast-llm prepare can ingest datasets with text, chosen, and rejected fields and write them to bin format.
  • GPTMemmapDataset and GPTSample are extended to support DPO while remaining backwards-compatible.
  • Existing training code paths (pretraining, SFT) remain unaffected.
  • Documentation updated with format and usage examples.

🛠️ Project Management

  • Add the issue to the Fast-LLM project.
  • Set the Estimate field (in days) in the GitHub project.
  • Use the Size field to categorize the PR size (Medium).
  • Assign an owner when opening the issue.

Metadata

Labels

enhancement (New feature or request)

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
