-
Notifications
You must be signed in to change notification settings - Fork 242
Expand file tree
/
Copy pathac_agent.py
More file actions
66 lines (50 loc) · 2.3 KB
/
ac_agent.py
File metadata and controls
66 lines (50 loc) · 2.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from collections import OrderedDict
from cs285.critics.bootstrapped_continuous_critic import \
BootstrappedContinuousCritic
from cs285.infrastructure.replay_buffer import ReplayBuffer
from cs285.infrastructure.utils import *
from cs285.policies.MLP_policy import MLPPolicyAC
from .base_agent import BaseAgent
class ACAgent(BaseAgent):
    """Actor-critic agent.

    A bootstrapped critic learns V(s) via TD targets; the actor (an MLP
    policy) is updated with advantage estimates
    A(s, a) = r(s, a) + gamma * V(s') - V(s).
    """

    def __init__(self, env, agent_params):
        super(ACAgent, self).__init__()

        self.env = env
        self.agent_params = agent_params

        self.gamma = self.agent_params['gamma']
        self.standardize_advantages = self.agent_params['standardize_advantages']

        self.actor = MLPPolicyAC(
            self.agent_params['ac_dim'],
            self.agent_params['ob_dim'],
            self.agent_params['n_layers'],
            self.agent_params['size'],
            self.agent_params['discrete'],
            self.agent_params['learning_rate'],
        )
        self.critic = BootstrappedContinuousCritic(self.agent_params)

        self.replay_buffer = ReplayBuffer()

    def train(self, ob_no, ac_na, re_n, next_ob_no, terminal_n):
        """Run one agent update: several critic fits, then several actor fits.

        Args:
            ob_no: batch of observations.
            ac_na: batch of actions taken.
            re_n: batch of rewards received.
            next_ob_no: batch of successor observations.
            terminal_n: batch of done flags (1 at episode ends, else 0).

        Returns:
            OrderedDict with the final 'Critic_Loss' and 'Actor_Loss'.
        """
        # Fit the critic first so the advantage estimates below use the
        # freshest value function.
        for _ in range(self.agent_params['num_critic_updates_per_agent_update']):
            critic_loss = self.critic.update(
                ob_no, ac_na, next_ob_no, re_n, terminal_n)

        adv_n = self.estimate_advantage(ob_no, next_ob_no, re_n, terminal_n)

        for _ in range(self.agent_params['num_actor_updates_per_agent_update']):
            actor_loss = self.actor.update(ob_no, ac_na, adv_n)

        # Only the last loss of each phase is reported for logging.
        loss = OrderedDict()
        loss['Critic_Loss'] = critic_loss
        loss['Actor_Loss'] = actor_loss
        return loss

    def estimate_advantage(self, ob_no, next_ob_no, re_n, terminal_n):
        """Return A(s, a) = r(s, a) + gamma * V(s') - V(s).

        The bootstrap term V(s') is zeroed at terminal states
        (terminal_n == 1) so value does not leak across episode boundaries.
        """
        # NOTE(review): assumes the critic exposes forward_np(obs) returning
        # V(s) as a numpy array — confirm against BootstrappedContinuousCritic.
        v_s = self.critic.forward_np(ob_no)
        v_sp = self.critic.forward_np(next_ob_no)

        # Q(s, a) = r(s, a) + gamma * V(s'), with V(s') cut at terminals.
        q_n = re_n + self.gamma * v_sp * (1 - terminal_n)
        adv_n = q_n - v_s

        if self.standardize_advantages:
            # Normalize to zero mean / unit variance; epsilon avoids
            # division by zero on constant batches.
            adv_n = (adv_n - np.mean(adv_n)) / (np.std(adv_n) + 1e-8)
        return adv_n

    def add_to_replay_buffer(self, paths):
        """Store freshly collected rollouts for later sampling."""
        self.replay_buffer.add_rollouts(paths)

    def sample(self, batch_size):
        """Sample the `batch_size` most recent transitions from the buffer."""
        return self.replay_buffer.sample_recent_data(batch_size)