
[BREAKING][vllm, fsdp] feat: add Rollout-Training Mismatch Fix -- Truncated importance sampling#2953

Merged
zhaochenyang20 merged 17 commits into verl-project:main from yaof20:truncated_importance_sampling
Aug 26, 2025

Conversation

@yaof20
Contributor

@yaof20 yaof20 commented Aug 7, 2025

What does this PR do?

Support correcting the vLLM–FSDP rollout–training policy mismatch using Truncated Importance Sampling (TIS):

[Figure: TIS formulation]
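For illustration, here is a minimal, hedged sketch of the correction in PyTorch. The function name and signature are hypothetical, not verl's actual API; in this PR the truncation value $C$ is configured via `actor_rollout_ref.actor.behav_imp_weight_cap`, and the rollout log probs are returned when `actor_rollout_ref.rollout.calculate_log_probs=True`:

```python
import torch

def tis_policy_loss(log_prob, old_log_prob, rollout_log_prob, advantages,
                    clip_ratio=0.2, imp_ratio_cap=2.0):
    """Illustrative PPO loss with Truncated Importance Sampling (TIS)."""
    # Standard PPO clipped surrogate between the current and old policies.
    ratio = torch.exp(log_prob - old_log_prob)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantages
    ppo_loss = -torch.min(surr1, surr2)

    # TIS: re-weight each token by the importance ratio between the training
    # backend's policy (FSDP) and the rollout backend's policy (vLLM),
    # truncated at C = imp_ratio_cap to bound the variance.
    if imp_ratio_cap > 0 and rollout_log_prob is not None:
        behav_imp_weight = torch.exp(old_log_prob - rollout_log_prob)
        behav_imp_weight = behav_imp_weight.clamp(max=imp_ratio_cap).detach()
        ppo_loss = ppo_loss * behav_imp_weight
    return ppo_loss.mean()
```

When the two backends agree exactly (`old_log_prob == rollout_log_prob`), the weight is 1 and this reduces to vanilla PPO.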

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=gae \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=1024 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path=Qwen/Qwen2.5-32B-Instruct \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size_per_gpu=8 \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_example' \
    trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=4 \
    trainer.save_freq=20 \
    trainer.test_freq=10 \
    trainer.total_epochs=15 \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    +actor_rollout_ref.actor.behav_imp_weight_cap=10.0 "$@"

The last two options enable TIS: `calculate_log_probs=True` makes the rollout return its log probs, and `behav_imp_weight_cap` sets the truncation value C. (Trailing comments after a `\` line continuation would break the shell command, so they are noted here instead.)

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

yaof20 and others added 3 commits August 6, 2025 09:30
@CLAassistant

CLAassistant commented Aug 7, 2025

CLA assistant check
All committers have signed the CLA.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces Truncated Importance Sampling (TIS) to correct for the mismatch between the rollout policy and the training policy in PPO. The changes look good overall, correctly implementing the TIS logic by re-weighting the PPO loss with a clipped importance ratio. The necessary configuration options and data plumbing have been added. There is one potential issue in a deprecated function that could lead to a runtime error if TIS is enabled but rollout log probs are not provided.

@yaof20
Contributor Author

yaof20 commented Aug 7, 2025

Hi VeRL Team,

This feature has been incorporated in OpenRLHF: https://github.com/OpenRLHF/OpenRLHF/releases/tag/v0.8.9

Thanks!

@zdhNarsil
Contributor

zdhNarsil commented Aug 7, 2025

Why $k_3$ is a biased KL estimator and how to fix it

In this PR, apart from our proposed truncated importance sampling, we also include a fix for KL estimator in this commit.

In John Schulman's blog, he introduces a KL estimator called the "$k_3$ estimator" for $\text{KL}(p_\theta | p_{\text{base}})$:

$$k_3(x) = r(x) - 1 - \log r(x), \quad r(x) = \frac{p_{\text{base}}(x)}{p_\theta(x)}$$

This is a valid estimator because one can easily show:

$$𝔼_{p_\theta(x)}[k_3(x)] = \text{KL}(p_\theta | p_{\text{base}})$$

For example, the DeepSeek team uses this estimator in their reinforcement learning training to obtain the R1 model.

However, we point out that this provides a biased gradient signal for model training; in fact, one can show:

$$ 𝔼_{p_\theta(x)}\left[\nabla_\theta k_3(x)\right] = 𝔼_{p_\theta(x)}\left[\nabla_\theta \left(\frac{p_{\text{base}}(x)}{p_\theta(x)}\right) - \nabla_\theta\left(\log \frac{p_{\text{base}}(x)}{p_\theta(x)}\right) \right] $$

Since $\nabla_\theta \log \frac{p_{\text{base}}(x)}{p_\theta(x)} = -\nabla_\theta \log p_\theta(x)$ and $𝔼_{p_\theta(x)}\left[\nabla_\theta \log p_\theta(x)\right] = 0$, the second term vanishes, leaving

$$ = 𝔼_{p_\theta(x)}\left[\nabla_\theta \left(\frac{p_{\text{base}}(x)}{p_\theta(x)}\right) \right] = 𝔼_{p_\theta(x)}\left[-\frac{p_{\text{base}}(x)}{p_\theta(x)} \nabla_\theta \log p_\theta(x) \right] $$

$$ \neq 𝔼_{p_\theta(x)}\left[-\log\frac{p_{\text{base}}(x)}{p_\theta(x)} \nabla_\theta \log p_\theta(x) \right] = \nabla_\theta\text{KL}(p_\theta | p_{\text{base}}) $$

To get an unbiased gradient estimate of $\text{KL}(p_\theta | p_{\text{base}})$, consider the $k_2$ estimator from the same blog:

$$k_2(x) = \frac{1}{2}\left(\log\frac{p_{\text{base}}(x)}{p_\theta(x)} \right)^2$$

$$\Rightarrow 𝔼_{p_\theta(x)}[\nabla_\theta k_2(x)] = 𝔼_{p_\theta(x)}\left[-\log\frac{p_{\text{base}}(x)}{p_\theta(x)} \nabla_\theta \log p_\theta(x) \right] = \nabla_\theta\text{KL}(p_\theta | p_{\text{base}})$$

How to fix

As a result, when we need the value of KL (e.g., to be added to the reward terms), we should use $k_1$ or $k_3$; when we need to compute the gradient of KL (e.g., as a training loss term), we should use $k_2$.

In our code commit, we use a straight-through estimator trick to implement this idea.
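The straight-through trick can be sketched as follows (PyTorch; the function name and tensors are illustrative, not the exact code in this commit): the forward pass reports the low-variance $k_3$ value, while the backward pass propagates the unbiased $k_2$ gradient.

```python
import torch

def kl_penalty_straight_through(log_prob, ref_log_prob):
    """Report the k3 KL value while backpropagating the k2 gradient."""
    log_r = ref_log_prob - log_prob      # log r(x) = log(p_base(x) / p_theta(x))
    k3 = torch.exp(log_r) - 1.0 - log_r  # unbiased, low-variance KL value
    k2 = 0.5 * log_r ** 2                # unbiased KL gradient
    # Straight-through: the forward value is k2 + (k3 - k2) = k3, but the
    # difference (k3 - k2) is detached, so only k2 contributes to the gradient.
    return k2 + (k3 - k2).detach()
```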

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@wuxibin89
Collaborator

Needs the vLLM fix vllm-project/vllm#22387 to return final_logprobs.

if self.config.use_kl_loss:
    select_keys.append("ref_log_prob")
if self.config.imp_ratio_cap > 0:
    select_keys.append("rollout_log_probs")
Collaborator

Server mode (agent loop) doesn't return rollout_log_probs for now.

Contributor Author

Hi, I have added a check here before adding rollout_log_probs.

Collaborator

@PeterSH6 PeterSH6 left a comment


Great work!

# support logging rollout prob for debugging purpose
calculate_log_probs: False
# "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling
calculate_log_probs: True
Collaborator

Could you turn the default value to False and add a running script for TIS?

Contributor Author

@yaof20 yaof20 Aug 7, 2025

I fixed it now and added a running script :)

entropy_coeff: 0

# whether to apply the truncated Importance Sampling (-1 for no importance sampling)
imp_ratio_cap: 5
Collaborator

Same here

Contributor Author

fixed :)

@yaof20 yaof20 requested review from PeterSH6 and wuxibin89 August 7, 2025 08:45
    loss_agg_mode: str = "token-mean",
    rollout_log_probs=None,
    imp_ratio_cap=-1,
):
Collaborator

This function is deprecated; you should change verl.trainer.ppo.core_algos.compute_policy_loss_vanilla instead.

Contributor Author

I see. I have included the change in verl.trainer.ppo.core_algos.compute_policy_loss_vanilla as well.

Let me delete the change in compute_policy_loss.

Contributor Author

I have fixed it now :)

@yaof20 yaof20 requested a review from eric-haibin-lin August 8, 2025 00:22
    loss_agg_mode=loss_agg_mode,
    config=self.config,
    rollout_log_probs=rollout_log_probs,
    imp_ratio_cap=self.config.imp_ratio_cap,
Collaborator

No need to pass it, as it's already included in the config.

@yaof20 yaof20 requested a review from eric-haibin-lin August 8, 2025 06:46
@zdhNarsil
Contributor

@jianfeng-Liu I do not quite understand your writing, but I guess you want to say

$$𝔼_{p_\theta(x)} \left[\nabla_\theta k_3(x)\right] = 𝔼_{p_\theta(x)}\left[-\frac{p_{\text{base}}(x)}{p_\theta(x)} \nabla_\theta \log p_\theta(x) + \nabla_\theta\log p_\theta(x)\right] $$
?

This is indeed correct; one can further notice that $𝔼_{p_\theta(x)}\left[ \nabla_\theta\log p_\theta(x) \right] = \nabla_\theta \int p_\theta(x) dx = \nabla_\theta 1 = 0 $, so you end up with $𝔼_{p_\theta(x)} \left[\nabla_\theta k_3(x)\right] = 𝔼_{p_\theta(x)}\left[-\frac{p_{\text{base}}(x)}{p_\theta(x)} \nabla_\theta \log p_\theta(x) \right] \neq \nabla_\theta\text{KL}(p_\theta | p_{\text{base}})$.
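As a quick numerical sanity check of the identity $𝔼_{p_\theta(x)}[\nabla_\theta \log p_\theta(x)] = 0$ used here (illustrative code, not part of the PR), one can verify it exactly for a small softmax categorical:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(5, requires_grad=True)
probs = torch.softmax(logits, dim=0)

# E_{p}[grad log p(x)] = sum_x p(x) * grad log p(x) = grad sum_x p(x) = grad 1 = 0
expected_score = torch.zeros(5)
for x in range(5):
    # Gradient of log p_theta(x) with respect to the logits.
    (g,) = torch.autograd.grad(torch.log(probs[x]), logits, retain_graph=True)
    expected_score += probs[x].detach() * g

# The probability-weighted sum of score vectors is the zero vector.
assert torch.allclose(expected_score, torch.zeros(5), atol=1e-6)
```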

model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
response_mask = model_inputs["response_mask"]
old_log_prob = model_inputs["old_log_probs"]
rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.tis_imp_ratio_cap > 0 else None


A very minor problem: what if the user specifies tis_imp_ratio_cap but does not specify calculate_log_probs=True? I would suggest directly checking whether "rollout_log_probs" is in model_inputs to avoid any risk of a KeyError (note this is already very deep in the verl codebase, so it could be hard for the user to debug).


# Truncated Importance Sampling (TIS): https://fengyao.notion.site/off-policy-rl
# the truncation value C of truncated Importance Sampling (-1 for disable TIS)
tis_imp_ratio_cap: -1


Do you think there should be some joint checking of the configuration? For example, if the user specifies tis_imp_ratio_cap, must calculate_log_probs be True?

Collaborator

Agreed. Please verify this in the config.

Contributor Author

Hi, we have already addressed this issue in verl/workers/actor/dp_actor.py as follows:

      if self.config.tis_imp_ratio_cap > 0:
          assert "rollout_log_probs" in data.batch.keys(), (
              "Truncated Importance Sampling (TIS) requires to configure "
              "`actor_rollout_ref.rollout.calculate_log_probs=True` "
              "and is not currently supported in Server mode (agent loop)."
          )
          select_keys.append("rollout_log_probs")


I see, thanks! Looks good to me now.

@zwhe99
Contributor

zwhe99 commented Oct 10, 2025

> Why $k_3$ is a biased KL estimator and how to fix it
> [full comment from @zdhNarsil quoted above]

@zdhNarsil Will the gradients of $k_2$ exhibit high variance, and will this affect training stability?
