[BREAKING][vllm, fsdp] feat: add Rollout-Training Mismatch Fix -- Truncated Importance Sampling #2953
Conversation
Straight-through trick for KL gradient estimation
Code Review
This pull request introduces Truncated Importance Sampling (TIS) to correct for the mismatch between the rollout policy and the training policy in PPO. The changes look good overall, correctly implementing the TIS logic by re-weighting the PPO loss with a clipped importance ratio. The necessary configuration options and data plumbing have been added. There is one potential issue in a deprecated function that could lead to a runtime error if TIS is enabled but rollout log probs are not provided.
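For readers unfamiliar with the technique, here is a minimal sketch of the re-weighting the review describes. The names, list-based shapes, and function signature are illustrative only, not verl's actual API: the per-token PPO loss is scaled by the importance ratio between the training policy's log probs and the rollout engine's log probs, truncated from above at a cap C.

```python
import math

def tis_weighted_loss(ppo_loss_tokens, old_log_probs, rollout_log_probs, imp_ratio_cap):
    """Re-weight a per-token PPO loss with a truncated importance ratio.

    Illustrative sketch only -- verl operates on batched tensors, not lists.
    ratio = exp(log pi_train(a) - log pi_rollout(a)), clipped from above at
    imp_ratio_cap so a few badly mismatched tokens cannot dominate the update.
    """
    weighted = []
    for loss, lp_train, lp_rollout in zip(ppo_loss_tokens, old_log_probs, rollout_log_probs):
        ratio = math.exp(lp_train - lp_rollout)
        ratio = min(ratio, imp_ratio_cap)  # truncation: one-sided clip from above
        weighted.append(ratio * loss)
    return sum(weighted) / len(weighted)  # "token-mean" aggregation
```

Note the clip is one-sided: ratios below 1 are kept as-is, only large ratios are capped, which is what distinguishes truncated importance sampling from the symmetric PPO clip.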
Hi VeRL Team, this feature has been incorporated in OpenRLHF: https://github.com/OpenRLHF/OpenRLHF/releases/tag/v0.8.9 Thanks!
Why
Need vllm fix vllm-project/vllm#22387 to return
```python
if self.config.use_kl_loss:
    select_keys.append("ref_log_prob")
if self.config.imp_ratio_cap > 0:
    select_keys.append("rollout_log_probs")
```
Server mode (agent loop) doesn't return `rollout_log_probs` for now.
Hi, I have added a check here before adding `rollout_log_probs`.
```diff
-# support logging rollout prob for debugging purpose
-calculate_log_probs: False
+# "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling
+calculate_log_probs: True
```
Could you turn the default value to False and add a running script for TIS?
I fixed it now and added a running script :)
verl/trainer/config/actor/actor.yaml
Outdated
```diff
 entropy_coeff: 0
+
+# whether to apply the truncated importance sampling (-1 for no importance sampling)
+imp_ratio_cap: 5
```
```python
    loss_agg_mode: str = "token-mean",
    rollout_log_probs=None,
    imp_ratio_cap=-1,
):
```
This function is deprecated; you should change `verl.trainer.ppo.core_algos.compute_policy_loss_vanilla` instead.
I see, I have included the change in `verl.trainer.ppo.core_algos.compute_policy_loss_vanilla` as well.
Let me delete the change in `compute_policy_loss`.
verl/workers/actor/dp_actor.py
Outdated
```python
    loss_agg_mode=loss_agg_mode,
    config=self.config,
    rollout_log_probs=rollout_log_probs,
    imp_ratio_cap=self.config.imp_ratio_cap,
```
No need to pass it; it's already included in the config.
@jianfeng-Liu I do not quite understand your writing, but I guess you want to say
This is indeed correct; and one can further notice that
```python
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
response_mask = model_inputs["response_mask"]
old_log_prob = model_inputs["old_log_probs"]
rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.tis_imp_ratio_cap > 0 else None
```
A very minor problem: what if the user specifies `tis_imp_ratio_cap` but does not specify `calculate_log_probs = True`? I would suggest directly checking whether `"rollout_log_probs"` is in `model_inputs` to avoid any risk of a `KeyError` (note this is already very deep in the verl codebase, so it could be hard for the user to debug).
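A sketch of this suggestion, with a hypothetical helper name and plain dicts standing in for verl's batch objects: look the key up defensively, so a misconfiguration surfaces as an explicit error at the lookup site rather than a bare `KeyError` deep in the training loop.

```python
def get_rollout_log_probs(model_inputs, tis_imp_ratio_cap):
    # Hypothetical helper illustrating the suggestion: if TIS is enabled but
    # rollout log probs were never recorded, fail with a clear message.
    if tis_imp_ratio_cap <= 0:
        return None  # TIS disabled: rollout log probs are not needed
    if "rollout_log_probs" not in model_inputs:
        raise ValueError(
            "tis_imp_ratio_cap > 0 requires calculate_log_probs=True "
            "so that rollout_log_probs is populated"
        )
    return model_inputs["rollout_log_probs"]
```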
```yaml
# Truncated Importance Sampling (TIS): https://fengyao.notion.site/off-policy-rl
# the truncation value C of truncated Importance Sampling (-1 for disable TIS)
tis_imp_ratio_cap: -1
```
Do you think there should be some joint checking of the configuration? For example, if the user specifies `tis_imp_ratio_cap`, `calculate_log_probs` must be True?
Agreed. Please verify this in the config.
Hi, we have already addressed this issue in verl/workers/actor/dp_actor.py as follows:
```python
if self.config.tis_imp_ratio_cap > 0:
    assert "rollout_log_probs" in data.batch.keys(), (
        "Truncated Importance Sampling (TIS) requires to configure "
        "`actor_rollout_ref.rollout.calculate_log_probs=True` "
        "and is not currently supported in Server mode (agent loop)."
    )
    select_keys.append("rollout_log_probs")
```
I see, thanks! Looks good to me now.
@zdhNarsil Will the gradients of K2 exhibit high variance, and will this affect training stability?
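For context on this question: K2 here presumably refers to the k2 approximate-KL estimator from John Schulman's "Approximating KL Divergence" note, where for r = p(x)/q(x) with x ~ q one uses k1 = -log r, k2 = (log r)² / 2, or k3 = (r - 1) - log r. A small sketch of the three estimators in my own notation (this is not verl's code):

```python
import math

def kl_estimators(log_ratio):
    """Per-sample estimators of KL(q || p), where log_ratio = log p(x) - log q(x), x ~ q.

    k1 is unbiased but high-variance (and can be negative per-sample);
    k2 is biased but low-variance and always non-negative;
    k3 is unbiased, non-negative, and typically low-variance.
    """
    r = math.exp(log_ratio)
    k1 = -log_ratio
    k2 = 0.5 * log_ratio ** 2
    k3 = (r - 1.0) - log_ratio
    return k1, k2, k3
```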
What does this PR do?
Support vLLM-FSDP off-policy importance sampling correction using Truncated Importance Sampling (TIS):
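Based on the config keys introduced in this PR, enabling TIS should look roughly like the following fragment; the cap value 2.0 is an arbitrary illustration, not a recommendation, and the exact nesting may differ from your config layout:

```yaml
actor_rollout_ref:
  rollout:
    # TIS needs the rollout engine's log probs to be recorded
    calculate_log_probs: True
  actor:
    # truncation value C of truncated importance sampling (-1 disables TIS)
    tis_imp_ratio_cap: 2.0
```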
Checklist Before Starting
- Use a title like `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If the PR is breaking, add `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`

Test
API and Usage Example
Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- Use the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)