Commit 3d98e74

Update core_algos.py
straight through trick for kl gradient estimation
1 parent 5e2181b commit 3d98e74

File tree

1 file changed

+13
-0
lines changed


verl/trainer/ppo/core_algos.py

@@ -1288,6 +1288,19 @@ def compute_value_loss(
 
 
+def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
+    """
+    The expectation of the k1 and k3 estimators is the expected value of the KL,
+    but the expected gradient of the k1 and k3 estimators is not the expected gradient of the KL!
+    The k2 estimator, on the other hand, gives the right gradient estimate,
+    so we use a straight-through trick here.
+    """
+    forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
+    backward_score = 0.5 * (logprob - ref_logprob).square()
+
+    return backward_score - backward_score.detach() + forward_score.detach()
+
+
 def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
     """Compute KL divergence given logprob and ref_logprob.
     Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
     See more description in http://joschu.net/blog/kl-approx.html
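The straight-through pattern in this commit can be demonstrated in isolation. Below is a minimal, self-contained sketch: the names `k1_estimator`, `k2_estimator`, and `straight_through_kl` are illustrative and not functions from the repo, and for simplicity the forward value here uses the k1 estimator directly rather than the repo's configurable `kl_penalty_forward`. The key identity is that `x - x.detach() + y.detach()` evaluates to `y` in the forward pass while gradients flow only through `x`.

```python
import torch


def k1_estimator(logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
    # k1 = logprob - ref_logprob: unbiased estimate of the KL value,
    # but its per-sample gradient is not the gradient of the KL.
    return logprob - ref_logprob


def k2_estimator(logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
    # k2 = 0.5 * (logprob - ref_logprob)^2: differentiating it yields
    # the right gradient estimate, d/dlogprob = (logprob - ref_logprob).
    return 0.5 * (logprob - ref_logprob).square()


def straight_through_kl(logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
    # Forward pass reports k1's value; backward pass uses k2's gradient.
    forward_score = k1_estimator(logprob, ref_logprob)
    backward_score = k2_estimator(logprob, ref_logprob)
    return backward_score - backward_score.detach() + forward_score.detach()


logprob = torch.tensor([0.2, -0.1], requires_grad=True)
ref_logprob = torch.tensor([0.0, 0.0])

st = straight_through_kl(logprob, ref_logprob)
# Forward value matches k1 exactly...
assert torch.allclose(st, logprob.detach() - ref_logprob)

# ...while the backward pass gives k2's gradient, (logprob - ref_logprob).
st.sum().backward()
assert torch.allclose(logprob.grad, logprob.detach() - ref_logprob)
```

The `detach()` calls are what make this work: `backward_score - backward_score.detach()` is identically zero in the forward pass but carries k2's gradient, and `forward_score.detach()` contributes the reported value with no gradient of its own.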
