Skip to content

Commit 4f19354

Browse files
committed
modify rdpg
1 parent fd20263 commit 4f19354

File tree

9 files changed

+10
-10
lines changed

9 files changed

+10
-10
lines changed
-18 Bytes
Binary file not shown.

common/policy_networks.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -104,8 +104,8 @@ def __init__(self, state_space, action_space, hidden_dim, action_range=1., init_
104104
self.linear4 = nn.Linear(hidden_dim, self._action_dim) # output dim = dim of action
105105

106106
# weights initialization
107-
self.linear3.weight.data.uniform_(-init_w, init_w)
108-
self.linear3.bias.data.uniform_(-init_w, init_w)
107+
self.linear4.weight.data.uniform_(-init_w, init_w)
108+
self.linear4.bias.data.uniform_(-init_w, init_w)
109109

110110

111111
def forward(self, state, last_action, hidden_in):
@@ -127,7 +127,7 @@ def forward(self, state, last_action, hidden_in):
127127
# merged
128128
merged_branch=torch.cat([fc_branch, lstm_branch], -1)
129129
x = activation(self.linear3(merged_branch))
130-
x = F.tanh(self.linear4(x)).clone()
130+
x = F.tanh(self.linear4(x))
131131
x = x.permute(1,0,2) # permute back
132132

133133
return x, lstm_hidden # lstm_hidden is actually tuple: (hidden, cell)
@@ -196,7 +196,7 @@ def forward(self, state, last_action, hidden_in):
196196
# hidden only for initialization, later on hidden states are passed automatically for sequential data
197197
x, lstm_hidden = self.lstm1(x, hidden_in) # no activation after lstm
198198
x = activation(self.linear2(x))
199-
x = F.tanh(self.linear3(x)).clone()
199+
x = F.tanh(self.linear3(x))
200200
x = x.permute(1,0,2) # permute back
201201

202202
return x, lstm_hidden # lstm_hidden is actually tuple: (hidden, cell)

ddpg.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -258,8 +258,8 @@ def _reverse_action(self, action):
258258
action_dim = env.num_actions
259259
state_dim = env.num_observations
260260
elif ENV == 'Pendulum':
261-
# env = NormalizedActions(gym.make("Pendulum-v0"))
262-
env = gym.make("Pendulum-v0")
261+
env = NormalizedActions(gym.make("Pendulum-v0"))
262+
# env = gym.make("Pendulum-v0")
263263
action_dim = env.action_space.shape[0]
264264
state_dim = env.observation_space.shape[0]
265265
elif ENV == 'HalfCheetah':

ddpg_v2.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -165,8 +165,8 @@ def _reverse_action(self, action):
165165
state_space = spaces.Box(low=-np.inf, high=np.inf, shape=(env.num_observations, ))
166166

167167
elif ENV == 'Pendulum':
168-
# env = NormalizedActions(gym.make("Pendulum-v0"))
169-
env = gym.make("Pendulum-v0")
168+
env = NormalizedActions(gym.make("Pendulum-v0"))
169+
# env = gym.make("Pendulum-v0")
170170
action_space = env.action_space
171171
state_space = env.observation_space
172172
hidden_dim = 64

model/rdpg_policy

0 Bytes
Binary file not shown.

model/rdpg_q

0 Bytes
Binary file not shown.

model/rdpg_target_q

0 Bytes
Binary file not shown.

rdpg.png

11 KB
Loading

rdpg.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -179,8 +179,8 @@ def _reverse_action(self, action):
179179
state_space = spaces.Box(low=-np.inf, high=np.inf, shape=(env.num_observations, ))
180180

181181
elif ENV == 'Pendulum':
182-
# env = NormalizedActions(gym.make("Pendulum-v0"))
183-
env = gym.make("Pendulum-v0")
182+
env = NormalizedActions(gym.make("Pendulum-v0"))
183+
# env = gym.make("Pendulum-v0")
184184
action_space = env.action_space
185185
state_space = env.observation_space
186186
hidden_dim = 64

0 commit comments

Comments (0)