1 file changed: +13 -0 lines changed
@@ -1288,6 +1288,19 @@ def compute_value_loss(


 def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
+    """
+    The expectation of the k1 and k3 estimators is the expected value of KL,
+    but the expected gradient of the k1 and k3 estimators is not the expected gradient of KL!
+    The k2 estimator, on the other hand, gives the right gradient estimator,
+    so we use a straight-through trick here.
+    """
+    forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
+    backward_score = 0.5 * (logprob - ref_logprob).square()
+
+    return backward_score - backward_score.detach() + forward_score.detach()
+
+
+def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
     """Compute KL divergence given logprob and ref_logprob.
     Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
     See more description in http://joschu.net/blog/kl-approx.html
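
The gradient claim in the new docstring can be checked numerically without PyTorch. The following is a minimal sketch, not part of the patch. It assumes a small categorical policy parameterized by logits, and takes the estimator forms from the linked blog post with `d = logprob - ref_logprob`: k1 = `d`, k2 = `0.5 * d**2`, k3 = `exp(-d) + d - 1`. The helper names (`softmax`, `kl`, `true_kl_grad`, `expected_sample_grad`) are hypothetical. Each sample is treated as a constant when differentiating, which is what autograd does with a sampled token.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(logits, ref_probs):
    """Exact KL(pi_theta || pi_ref) for a categorical policy."""
    probs = softmax(logits)
    return sum(p * (math.log(p) - math.log(q)) for p, q in zip(probs, ref_probs))


def true_kl_grad(logits, ref_probs, eps=1e-6):
    """Central finite-difference gradient of the exact KL w.r.t. the logits."""
    grad = []
    for i in range(len(logits)):
        up = list(logits)
        up[i] += eps
        down = list(logits)
        down[i] -= eps
        grad.append((kl(up, ref_probs) - kl(down, ref_probs)) / (2 * eps))
    return grad


def expected_sample_grad(logits, ref_probs, estimator):
    """Exact expectation over x ~ pi_theta of the per-sample estimator gradient.

    With d = logprob - ref_logprob, the derivative of each estimator
    w.r.t. logprob is: k1 -> 1, k2 -> d, k3 -> 1 - exp(-d).
    """
    probs = softmax(logits)
    n = len(logits)
    grad = [0.0] * n
    for x in range(n):
        d = math.log(probs[x]) - math.log(ref_probs[x])
        coeff = {"k1": 1.0, "k2": d, "k3": 1.0 - math.exp(-d)}[estimator]
        for i in range(n):
            # d log softmax(logits)[x] / d logits[i] = 1[i == x] - probs[i]
            dlogp = (1.0 if i == x else 0.0) - probs[i]
            grad[i] += probs[x] * coeff * dlogp
    return grad


# Illustrative values only; any policy that differs from the reference works.
logits = [0.5, -1.0, 2.0]
ref_probs = softmax([0.0, 0.3, 1.0])

exact = true_kl_grad(logits, ref_probs)
grad_k1 = expected_sample_grad(logits, ref_probs, "k1")
grad_k2 = expected_sample_grad(logits, ref_probs, "k2")
grad_k3 = expected_sample_grad(logits, ref_probs, "k3")
```

In this setup the k2 gradient should match the finite-difference gradient of the exact KL, the k1 gradient should vanish (the expected score function is zero), and the k3 gradient should differ from the exact one. This is consistent with the patch keeping the selected estimator's value in the forward pass while routing the backward pass through `0.5 * (logprob - ref_logprob).square()`.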