Add SDPO recipe with SciKnowEval support #419

Draft
YujiaBao wants to merge 21 commits into thinking-machines-lab:main from YujiaBao:sdpo-recipe-v2

Conversation

YujiaBao (Member) commented Mar 10, 2026

Self-Distilled Policy Optimization (SDPO)

Implements the SDPO algorithm from Huebotter et al., 2026 (official implementation) as a tinker cookbook recipe.

What is SDPO?

Standard RL algorithms like GRPO assign a single scalar reward per sequence — every token in a correct solution gets the same credit, and every token in an incorrect one gets the same blame. SDPO provides dense, token-level credit assignment by using the model's own successful solutions as a teaching signal.

The key insight (Proposition 2.1): the SDPO gradient is a policy gradient where each token gets its own advantage, computed as the log-ratio of a solution-conditioned teacher to the student:

advantage_t = log π_teacher(y_t | prompt, solution, y_<t) − log π_student(y_t | prompt, y_<t)

Tokens where the teacher is more confident than the student get positive advantage (reinforced), and vice versa. This maps directly to tinker's importance_sampling loss — the same pattern used by tinker's online distillation — so we avoid forward_backward_custom entirely.
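The per-token advantage above is just a difference of logprobs on the sampled tokens. A minimal sketch (illustrative names, not the recipe's API):

```python
# Hypothetical sketch of the SDPO per-token advantage:
# advantage_t = log pi_teacher(y_t | prompt, solution, y_<t)
#             - log pi_student(y_t | prompt, y_<t)
def sdpo_advantages(teacher_lps, student_lps):
    """Per-token advantages from teacher/student logprobs of the sampled tokens."""
    assert len(teacher_lps) == len(student_lps)
    return [t - s for t, s in zip(teacher_lps, student_lps)]
```

A token where the teacher is more confident (teacher_lp > student_lp) gets a positive advantage; a token where student and teacher already agree gets an advantage near zero.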

Algorithm flow

For each batch, the training loop repeats the following steps:

Step 1 — Rollout. Sample group_size (default 8) responses per problem using the current policy.

Step 2 — Identify successes. Check which responses solved the problem correctly (e.g. correct \boxed{} answer for math, correct MCQ letter for SciKnowEval).
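Grading in step 2 amounts to extracting the final answer from the response text. The graders below are illustrative sketches; the recipe's own checkers may differ:

```python
import re

# Illustrative answer extractors (assumed shapes, not the recipe's exact code).
def extract_boxed(text):
    """Return the content of the last \\boxed{...} in a math solution, or None."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1] if matches else None

def extract_mcq_letter(text):
    """Return the last standalone answer letter A-D, or None."""
    matches = re.findall(r"\b([A-D])\b", text)
    return matches[-1] if matches else None
```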

Step 3 — Build teacher prompt. Condition the reference model on information the student doesn't have. This can include:

  • A successful solution from another rollout in the group (primary signal for math/MCQ).
  • Environment feedback from the trajectory's own execution, e.g. compiler errors or failing test cases (useful for code tasks; enable with include_environment_feedback=True).
  • Both — the paper (Table 6) shows these are complementary: solution alone 42.6%, feedback alone 39.9%, both together 48.3%.
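A teacher prompt conditioned on these signals might look like the following sketch (field names and wording are assumptions, not the recipe's actual templates):

```python
# Hypothetical teacher-prompt construction, conditioning the reference model
# on a peer solution and/or environment feedback the student never saw.
def build_teacher_prompt(problem, solution=None, feedback=None):
    parts = [f"Problem:\n{problem}"]
    if solution is not None:
        parts.append(f"A correct solution for reference:\n{solution}")
    if feedback is not None:
        parts.append(f"Execution feedback from a previous attempt:\n{feedback}")
    parts.append("Now solve the problem:")
    return "\n\n".join(parts)
```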

Step 4 — Compute teacher logprobs. Use the frozen reference model to score each trajectory's response tokens under this conditioned teacher prompt.

Step 5 — Build advantages. For each response token, set advantage_t = teacher_lp_t − student_lp_t. Tokens where the teacher is more confident get positive advantages (reinforced); tokens where the student already matches get near-zero advantages.

Step 6 — Train. Call forward_backward(datums, loss_fn="importance_sampling") followed by optim_step(). The importance weight automatically corrects for any off-policy drift. Refresh the sampling client with the updated weights.
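The off-policy correction in step 6 can be illustrated with a toy per-token importance-sampled objective (not tinker's API, just the idea the `importance_sampling` loss implements):

```python
import math

# Toy importance-sampled objective: each token's advantage is weighted by the
# ratio of the current policy's probability to the sampling-time probability,
# correcting for drift between rollout and update.
def is_weighted_objective(cur_lps, sampled_lps, advantages):
    total = 0.0
    for cur, old, adv in zip(cur_lps, sampled_lps, advantages):
        ratio = math.exp(cur - old)  # pi_current(y_t) / pi_sampling(y_t)
        total += ratio * adv
    return total / len(advantages)
```

When training is fully on-policy the ratios are all 1 and the objective reduces to the mean advantage.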

Teacher regularization

We use a frozen reference model (θ_ref) as the teacher base. Table 4 in the paper shows this achieves 48.8% accuracy vs 36.1% for the unregularized variant, while being simpler than EMA (49.3%) or trust-region (50.6%).

Results on MATH

Using Qwen3-8B on Hendrycks MATH (MATH-500 eval), SDPO improves test accuracy from 43.0% to 63.0% after 180 steps:

| Step | test/env/all/correct | sdpo/success_fraction | sdpo/mean_group_success_rate |
|-----:|---------------------:|----------------------:|-----------------------------:|
| 0    | 0.430                | 0.547                 | 0.408                        |
| 180  | 0.630                | 0.781                 | 0.645                        |

Code structure

tinker_cookbook/sdpo/                  # Reusable core
├── data.py                           # build_sdpo_datum (advantages = teacher_lp - student_lp)
├── teacher.py                        # Teacher prompt construction + logprob computation
└── train.py                          # Config + main() training loop

tinker_cookbook/recipes/sdpo/          # Recipe CLI
├── train.py                          # Thin CLI wrapper
├── sciknoweval_env.py                # Paper's SciKnowEval MCQ environment
└── README.md

Environments

  • SciKnowEval — the paper's primary benchmark (hicai-zju/SciKnowEval), MCQ science questions with domains: chemistry, physics, biology, material
  • MATH — Hendrycks MATH (12k train / MATH-500 test)
  • GSM8K, Polaris, DeepMath — reused from math_rl recipe

Differences from the paper

  • Token-level vs full-logit distillation: This is the main gap. The paper computes JSD (α=0.5) over the full vocabulary distribution at each position (with top-k=100 approximation), which tells the student "redistribute your probability mass across all tokens to match the teacher." Our implementation only computes the advantage on the sampled token (teacher_lp(y_t) - student_lp(y_t)) via compute_logprobs, which misses the teacher's signal about all the other tokens in the vocabulary. Closing this gap would require tinker's forward pass to return top-k logits (not just the logprob of the target token), which would allow us to compute a proper KL over the approximate vocabulary distribution.
  • LoRA vs full fine-tuning: The paper uses full fine-tuning; this recipe uses LoRA (rank 32) for efficiency.
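For reference, the paper's full-distribution objective is a generalized Jensen-Shannon divergence. A minimal sketch over a truncated top-k distribution (purely illustrative; this recipe does not implement it):

```python
import math

# Generalized JSD with interpolation weight alpha (the paper uses alpha=0.5),
# computed over a small (e.g. top-k-renormalized) probability vector.
def jsd(p, q, alpha=0.5):
    m = [alpha * pi + (1 - alpha) * qi for pi, qi in zip(p, q)]
    def kl(a, b):
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)
    return alpha * kl(p, m) + (1 - alpha) * kl(q, m)
```

Identical distributions give a divergence of zero; the further the student's top-k mass drifts from the teacher's, the larger the value.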

Usage

# SciKnowEval (paper's benchmark)
uv run python -m tinker_cookbook.recipes.sdpo.train \
    model_name="Qwen/Qwen3-8B" env=sciknoweval \
    sciknoweval_domain=chemistry group_size=8 \
    groups_per_batch=32 learning_rate=1e-5 max_tokens=8192

# MATH
uv run python -m tinker_cookbook.recipes.sdpo.train \
    model_name="Qwen/Qwen3-8B" env=math \
    group_size=16 groups_per_batch=64 \
    learning_rate=2e-5 max_tokens=512

# Compare with GRPO baseline (same settings)
uv run python -m tinker_cookbook.recipes.math_rl.train \
    env=math model_name="Qwen/Qwen3-8B" \
    group_size=16 groups_per_batch=64 \
    learning_rate=2e-5 max_tokens=512

Metric names match the GRPO recipe (env/all/correct, env/all/reward/total, test/env/all/correct) so curves overlay directly in W&B.

Test plan

  • 17 unit tests passing (uv run pytest tinker_cookbook/tests/test_sdpo.py)
  • Existing tests unaffected (139 passed)
  • End-to-end run on MATH (43% → 63% after 180 steps)
  • End-to-end run on SciKnowEval
  • Side-by-side SDPO vs GRPO comparison with matched settings

🤖 Generated with Claude Code

YujiaBao and others added 13 commits March 16, 2026 13:55
Implements the SDPO algorithm (arXiv:2601.20802) as a cookbook recipe.
SDPO augments on-policy RL by distilling from the model's own successful
trajectories, providing dense token-level credit assignment via KL
divergence between student and solution-conditioned teacher.

Uses frozen reference model as teacher (Table 4: 48.8 vs 36.1
unregularized) with token-level reverse KL loss. Defaults to
Qwen/Qwen3-8B on Hendrycks MATH, matching the paper's setup.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…tructure

Major changes:
- Replace forward_backward_custom with forward_backward + importance_sampling
  by encoding SDPO signal as advantages (teacher_lp - student_lp). This avoids
  the 1.5-3x overhead and follows the same pattern as tinker's online
  distillation. The importance weight also handles off-policy correction.
- Split monolithic train.py into reusable modules under tinker_cookbook/sdpo/
  (loss, data, teacher, train) with a thin CLI wrapper in recipes/sdpo/.
- Add SciKnowEval MCQ environment (hicai-zju/SciKnowEval) matching the
  paper's primary benchmark.
- Add dont_reprompt_on_self_success and remove_thinking_from_demonstration
  flags matching the paper's defaults.
- Fix sync-in-async warnings by using async API variants.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- ruff format all SDPO files
- Fix pyright error in sciknoweval_env.py (row type annotation)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Remove sdpo/loss.py (dead code with incompatible datum format)
- Fix groups with 1 success producing zero datums: fall back to self
  when dont_reprompt_on_self_success=True but no other success exists
- Guard against None in compute_logprobs return values
- Deduplicate _gather_with_progress by importing from rl.train
- Add SciKnowEval grading tests (_extract_answer, _format_choices)
- Fix unused imports (ruff)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Qwen3-8B: 43.0% -> 63.0% on MATH-500 after 180 steps.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The paper conditions the teacher on both successful solutions AND environment
feedback (e.g. compiler errors, failing test cases). Table 6 shows these are
complementary: solution alone 42.6%, feedback alone 39.9%, both 48.3%.

- Add extract_feedback() to pull feedback from transition logs
- Update build_teacher_prompt() to accept optional solution and feedback
- Add include_environment_feedback config flag (default False)
- Environments can provide feedback via StepResult.logs["feedback"]
- Update docstrings to reflect feedback support

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Remove double test/ prefix on eval metrics (RLTestSetEvaluator already
  adds test/ prefix)
- Add training rollout metrics (env/all/correct, env/all/reward/total)
  without prefix, matching GRPO convention so curves overlay in W&B

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
YujiaBao and others added 8 commits March 17, 2026 15:48
Wire SDPO to the DeepCoder (LiveCodeBench v6) dataset via `env=code_rl`,
matching the paper's single-turn code generation setup. The model writes
code once, the sandbox executes it, and execution feedback (compiler
errors, test failures) conditions the teacher for dense credit assignment.

Key changes:
- Propagate tool result text as `logs["feedback"]` through MessageStepResult
  so SDPO's `extract_feedback()` picks it up automatically
- Add `build_teacher_prompt_from_messages()` for generic env support
  (code envs need tool schemas in the conversation prefix)
- Dispatch teacher prompt construction based on builder type
  (ProblemGroupBuilder for math, DeepcoderEnvGroupBuilder for code)
- Add `env=code_rl` CLI option with sandbox_backend config

Verified end-to-end: rollouts with Modal sandbox, code execution,
feedback extraction, teacher logprob scoring, and SDPO training step
all complete successfully on DeepCoder.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Adds trajectory logging to the SDPO recipe to match GRPO's output:
- Logtree HTML + JSON reports for both eval and train iterations
- Rollout summary JSONL export for eval and train
- New config options: num_groups_to_log, rollout_json_export

This enables trajectory inspection during SDPO training runs.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…logprobs

When the teacher prompt (which includes solution + feedback) plus response
tokens exceeds the model's context window, compute_teacher_logprobs now
truncates the response to fit. Positions beyond the truncation get
advantage=0, contributing no gradient.

Adds max_context_length config (default 32768) to control this.

Fixes: "Prompt length plus max_tokens exceeds the model's context window"
errors when using SDPO with models that have shorter context windows
(e.g. Qwen3-4B-Instruct-2507 with 32k context).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Logs three new metrics per SDPO training step:
- sdpo/teacher_truncated_count: number of trajectories truncated
- sdpo/teacher_truncated_frac: fraction of trained trajectories truncated
- sdpo/teacher_truncated_tokens_avg: avg tokens truncated when it happens

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…context

When teacher_prompt_len >= max_context_length, we previously returned
torch.zeros(len(response_tokens)). This created teacher_lps of all zeros,
making advantages = 0 - student_lp = negative for every position, which
incorrectly suppressed the student's behavior.

Now returns an empty tensor, so build_sdpo_datum pads all positions with
advantage=0 (no gradient), which is the correct behavior.
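The distinction between "zeros" and "empty" can be sketched as follows (plain lists stand in for the actual tensor code; names are illustrative):

```python
# Sketch of the fix: when the teacher prompt alone exceeds the context window,
# return an empty result so the datum builder pads every position with
# advantage=0 (no gradient). Returning zeros instead would make
# advantage_t = 0 - student_lp_t, i.e. positive everywhere the student is
# uncertain and never negative -- a spurious training signal.
def teacher_logprobs_or_empty(teacher_prompt_len, max_context_length, response_lps):
    if teacher_prompt_len >= max_context_length:
        return []  # caller pads all positions with advantage 0
    return response_lps
```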

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
TinkerTokenCompleter and TinkerMessageCompleter now accept an optional
max_context_length parameter. When set, max_tokens is clamped so that
prompt_length + max_tokens <= max_context_length, preventing
BadRequestError crashes when long prompts leave insufficient room.

SDPO recipe passes max_context_length from its config to the completer.
Other recipes are unaffected (max_context_length defaults to None).
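The clamping logic described above is simple arithmetic; a sketch (assumed parameter names):

```python
# Clamp max_tokens so that prompt_length + max_tokens <= max_context_length.
# max_context_length=None disables clamping (the default for other recipes).
def clamp_max_tokens(max_tokens, prompt_length, max_context_length=None):
    if max_context_length is None:
        return max_tokens
    return max(0, min(max_tokens, max_context_length - prompt_length))
```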

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
write_rollout_summaries_jsonl_from_groups expects groups_P (a list of
RolloutSummaryGroup), not separate trajectory_groups_P/taglist_P args.
Switch to write_rollout_summaries_jsonl which accepts them directly.

Fixes TypeError crash after first rollout batch in SDPO training.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>