Skip to content

Commit 577e61d

Browse files
committed
update ppo discrete
1 parent 4f19354 commit 577e61d

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

ppo_continuous_multiprocess2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151
##################### hyper parameters ####################
5252

53-
ENV_NAME = 'Pendulum-v0' # environment name
53+
ENV_NAME = 'LunarLanderContinuous-v2' # environment name, e.g. LunarLanderContinuous-v2 or Pendulum-v0
5454
RANDOMSEED = 2 # random seed
5555

5656
EP_MAX = 1000 # total number of episodes for training
@@ -63,7 +63,7 @@
6363
C_UPDATE_STEPS = 10 # critic update steps
6464
EPS = 1e-8 # numerical residual
6565
MODEL_PATH = 'model/ppo_multi'
66-
NUM_WORKERS=2 # or: mp.cpu_count()
66+
NUM_WORKERS=1 # or: mp.cpu_count()
6767
ACTION_RANGE = 2. # action range of the unnormalized environment; if actions are normalized, this should be 1.
6868
METHOD = [
6969
dict(name='kl_pen', kl_target=0.01, lam=0.5), # KL penalty

ppo_gae_discrete.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
T_horizon = 20
1515

1616
class PPO(nn.Module):
17-
def __init__(self):
17+
def __init__(self, state_dim, action_dim):
1818
super(PPO, self).__init__()
1919
self.data = []
2020

21-
self.fc1 = nn.Linear(4,256)
22-
self.fc_pi = nn.Linear(256,2)
21+
self.fc1 = nn.Linear(state_dim,256)
22+
self.fc_pi = nn.Linear(256,action_dim)
2323
self.fc_v = nn.Linear(256,1)
2424
self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
2525

@@ -86,7 +86,9 @@ def train_net(self):
8686

8787
def main():
8888
env = gym.make('CartPole-v1')
89-
model = PPO()
89+
state_dim = env.observation_space.shape[0]
90+
action_dim = env.action_space.n # discrete
91+
model = PPO(state_dim, action_dim)
9092
score = 0.0
9193
print_interval = 20
9294

@@ -99,8 +101,9 @@ def main():
99101
m = Categorical(prob)
100102
a = m.sample().item()
101103
s_prime, r, done, info = env.step(a)
102-
104+
# env.render()
103105
model.put_data((s, a, r/100.0, s_prime, prob[a].item(), done))
106+
104107
s = s_prime
105108

106109
score += r

0 commit comments

Comments
 (0)