
Commit 2dd6ec5

add ppo without gae
1 parent d579b3d commit 2dd6ec5

2 files changed: +130 -0

ppo_discrete.py

Lines changed: 127 additions & 0 deletions
"""
Proximal Policy Optimization for discrete (action space) environments, without GAE.
"""
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Hyperparameters
learning_rate = 0.0005
gamma = 0.98
eps_clip = 0.1
K_epoch = 3
T_horizon = 20

class PPO(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PPO, self).__init__()
        self.data = []

        self.fc1 = nn.Linear(state_dim, 256)
        self.fc_pi = nn.Linear(256, action_dim)
        self.fc_v = nn.Linear(256, 1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def pi(self, x, softmax_dim=0):
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)
        return prob

    def v(self, x):
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v

    def put_data(self, transition):
        self.data.append(transition)

    def make_batch(self):
        s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst = [], [], [], [], [], []
        for transition in self.data:
            s, a, r, s_prime, prob_a, done = transition

            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            prob_a_lst.append([prob_a])
            done_mask = 0 if done else 1
            done_lst.append([done_mask])

        s, a, r, s_prime, done_mask, prob_a = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
                                              torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float), \
                                              torch.tensor(done_lst, dtype=torch.float), torch.tensor(prob_a_lst)
        self.data = []
        return s, a, r, s_prime, done_mask, prob_a

    def train_net(self):
        s, a, r, s_prime, done_mask, prob_a = self.make_batch()

        for i in range(K_epoch):
            td_target = r + gamma * self.v(s_prime) * done_mask
            # Without GAE there is no lambda-weighted backward pass over the
            # TD errors: the advantage is simply the one-step TD error,
            # detached so the policy loss does not backprop into the critic.
            advantage = (td_target - self.v(s)).detach()

            pi = self.pi(s, softmax_dim=-1)
            # entropy() yields shape (T,); unsqueeze to (T, 1) so the entropy
            # bonus broadcasts correctly against the (T, 1) surrogate terms
            dist_entropy = Categorical(pi).entropy().unsqueeze(1)
            pi_a = pi.gather(1, a)
            ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))  # a/b == exp(log(a)-log(b))

            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage
            loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s), td_target.detach()) - 0.01 * dist_entropy

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

def main():
    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n  # discrete
    model = PPO(state_dim, action_dim)
    score = 0.0
    epi_len = []
    print_interval = 20

    for n_epi in range(10000):
        s = env.reset()  # old gym API (pre-0.26): reset() returns only the observation
        done = False
        while not done:
            for t in range(T_horizon):
                prob = model.pi(torch.from_numpy(s).float())
                m = Categorical(prob)
                a = m.sample().item()
                s_prime, r, done, info = env.step(a)
                # env.render()
                model.put_data((s, a, r / 100.0, s_prime, prob[a].item(), done))

                s = s_prime

                score += r
                if done:
                    break

            model.train_net()
        epi_len.append(t)
        if n_epi % print_interval == 0 and n_epi != 0:
            print("# of episode :{}, avg score : {:.1f}, avg epi length :{}".format(
                n_epi, score / print_interval, int(np.mean(epi_len))))
            score = 0.0

    env.close()


if __name__ == '__main__':
    main()
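
For intuition, here is a standalone toy check of the clipped surrogate used above (illustrative only, not part of the commit; the ratio and advantage values are made up). With a positive advantage, once the ratio pi_new/pi_old exceeds 1 + eps_clip, the clamped branch wins the min and the gradient through the ratio vanishes, which is what stops a single update from dragging the policy too far:

import torch

eps_clip = 0.1
ratio = torch.tensor([0.5, 0.95, 1.0, 1.05, 2.0])  # pi_new(a|s) / pi_old(a|s)
advantage = torch.ones(5)                           # positive advantages, for illustration

surr1 = ratio * advantage
surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage
print(-torch.min(surr1, surr2))  # tensor([-0.5000, -0.9500, -1.0000, -1.0500, -1.1000])

Note the asymmetry: the entry at ratio 0.5 is not clamped, because for positive advantages min() only caps the objective from above.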

ppo_gae_discrete.py

Lines changed: 3 additions & 0 deletions
"""
Proximal Policy Optimization for discrete (action space) environments, via Generalized Advantage Estimation (GAE).
"""
import gym
import torch
import torch.nn as nn
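
Only the new docstring of ppo_gae_discrete.py appears in this diff, so for reference here is a minimal sketch of what a GAE(lambda) backward recursion computes, contrasted with the one-step TD advantage used in ppo_discrete.py (illustrative only: the TD-error values are made up and episode boundaries are ignored):

import torch

gamma, lmbda = 0.98, 0.95
delta = torch.tensor([0.5, -0.2, 0.1, 0.3])  # hypothetical one-step TD errors

# Without GAE (ppo_discrete.py): the advantage is the TD error itself.
adv_no_gae = delta.clone()

# With GAE: backward pass accumulating (gamma * lmbda)-discounted future errors.
adv_gae = torch.zeros_like(delta)
running = 0.0
for t in reversed(range(len(delta))):
    running = delta[t] + gamma * lmbda * running
    adv_gae[t] = running

print(adv_no_gae)  # tensor([ 0.5000, -0.2000,  0.1000,  0.3000])
print(adv_gae)     # tensor([ 0.6426,  0.1531,  0.3793,  0.3000])

Each GAE entry folds in lambda-discounted future TD errors, trading a little bias for lower variance than the one-step estimate.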
