Skip to content

Commit 99a0fcc

Browse files
committed
add entropy bonus for ppo gae discrete
1 parent d11a5a1 commit 99a0fcc

File tree

6 files changed

+7
-5
lines changed

6 files changed

+7
-5
lines changed
54 Bytes
Binary file not shown.

model/rdpg_policy

0 Bytes
Binary file not shown.

model/rdpg_q

0 Bytes
Binary file not shown.

model/rdpg_target_q

0 Bytes
Binary file not shown.

ppo_gae_discrete.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,14 @@ def train_net(self):
7272
advantage_lst.reverse()
7373
advantage = torch.tensor(advantage_lst, dtype=torch.float)
7474

75-
pi = self.pi(s, softmax_dim=1)
75+
pi = self.pi(s, softmax_dim=-1)
76+
dist_entropy = Categorical(pi).entropy()
7677
pi_a = pi.gather(1,a)
7778
ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a)) # a/b == exp(log(a)-log(b))
7879

7980
surr1 = ratio * advantage
8081
surr2 = torch.clamp(ratio, 1-eps_clip, 1+eps_clip) * advantage
81-
loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s) , td_target.detach())
82+
loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s) , td_target.detach()) - 0.01*dist_entropy
8283

8384
self.optimizer.zero_grad()
8485
loss.mean().backward()
@@ -90,6 +91,7 @@ def main():
9091
action_dim = env.action_space.n # discrete
9192
model = PPO(state_dim, action_dim)
9293
score = 0.0
94+
epi_len = []
9395
print_interval = 20
9496

9597
for n_epi in range(10000):
@@ -111,9 +113,9 @@ def main():
111113
break
112114

113115
model.train_net()
114-
116+
epi_len.append(t)
115117
if n_epi%print_interval==0 and n_epi!=0:
116-
print("# of episode :{}, avg score : {:.1f}".format(n_epi, score/print_interval))
118+
print("# of episode :{}, avg score : {:.1f}, avg epi length :{}".format(n_epi, score/print_interval, int(np.mean(epi_len))))
117119
score = 0.0
118120

119121
env.close()

sac_v2_lstm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def plot(rewards):
196196
replay_buffer = ReplayBufferLSTM2(replay_buffer_size)
197197

198198
# choose env
199-
ENV = ['Reacher', 'Pendulum-v0', 'HalfCheetah-v2'][2]
199+
ENV = ['Reacher', 'Pendulum-v0', 'HalfCheetah-v2'][1]
200200
if ENV == 'Reacher':
201201
NUM_JOINTS=2
202202
LINK_LENGTH=[200, 140]

0 commit comments

Comments
 (0)