-
Notifications
You must be signed in to change notification settings - Fork 242
Expand file tree
/
Copy pathdqn_agent.py
More file actions
107 lines (85 loc) · 4.25 KB
/
dqn_agent.py
File metadata and controls
107 lines (85 loc) · 4.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy as np
from cs285.infrastructure.dqn_utils import MemoryOptimizedReplayBuffer, PiecewiseSchedule
from cs285.policies.argmax_policy import ArgMaxPolicy
from cs285.critics.dqn_critic import DQNCritic
class DQNAgent(object):
    """DQN agent.

    Interacts with the environment using epsilon-greedy exploration, stores
    transitions online in a MemoryOptimizedReplayBuffer, and periodically
    trains a Q-critic (syncing its target network every
    `target_update_freq` parameter updates).
    """

    def __init__(self, env, agent_params):
        self.env = env
        self.agent_params = agent_params
        self.batch_size = agent_params['batch_size']
        # Invariant: self.last_obs always points at the newest observation.
        self.last_obs = self.env.reset()

        self.num_actions = agent_params['ac_dim']
        self.learning_starts = agent_params['learning_starts']
        self.learning_freq = agent_params['learning_freq']
        self.target_update_freq = agent_params['target_update_freq']

        self.replay_buffer_idx = None
        self.exploration = agent_params['exploration_schedule']
        self.optimizer_spec = agent_params['optimizer_spec']

        self.critic = DQNCritic(agent_params, self.optimizer_spec)
        self.actor = ArgMaxPolicy(self.critic)

        # LunarLander uses low-dimensional observations, so the buffer can
        # skip the image-frame packing used for pixel-based envs.
        lander = agent_params['env_name'].startswith('LunarLander')
        self.replay_buffer = MemoryOptimizedReplayBuffer(
            agent_params['replay_buffer_size'],
            agent_params['frame_history_len'],
            lander=lander)

        self.t = 0                    # total env steps taken
        self.num_param_updates = 0    # total critic updates performed

    def add_to_replay_buffer(self, paths):
        # Transitions are stored online in step_env(); nothing to do here.
        pass

    def step_env(self):
        """
        Step the env and store the transition.

        At the end of this method the simulator has been advanced one step
        and the replay buffer contains one more transition. self.last_obs
        always points to the newest observation.
        """
        # Store the latest frame first so the buffer can assemble a
        # frame-history stack that includes it.
        self.replay_buffer_idx = self.replay_buffer.store_frame(self.last_obs)

        eps = self.exploration.value(self.t)

        # Epsilon-greedy: act randomly with probability eps, and always act
        # randomly before learning starts (replay-buffer warm-up).
        perform_random_action = (
            np.random.random() < eps or self.t < self.learning_starts
        )
        if perform_random_action:
            action = np.random.randint(self.num_actions)
        else:
            # The actor consumes the most recent `frame_history_len` frames
            # to cope with partial observability; the replay buffer builds
            # that stacked observation for us.
            enc_last_obs = self.replay_buffer.encode_recent_observation()
            action = self.actor.get_action(enc_last_obs)

        # Advance the simulator; keep last_obs pointing at the newest frame.
        obs, reward, done, info = self.env.step(action)
        self.last_obs = obs

        # Record the effect of the action alongside the frame stored above.
        self.replay_buffer.store_effect(
            self.replay_buffer_idx, action, reward, done)

        # Episode ended: reset the env (and the latest observation).
        if done:
            self.last_obs = self.env.reset()

    def sample(self, batch_size):
        """Sample a training batch, or five empty lists when the buffer
        cannot yet supply `self.batch_size` transitions."""
        if self.replay_buffer.can_sample(self.batch_size):
            return self.replay_buffer.sample(batch_size)
        else:
            return [], [], [], [], []

    def train(self, ob_no, ac_na, re_n, next_ob_no, terminal_n):
        """Run one critic update (when due) and return its log dict.

        Updates happen only after `learning_starts` env steps, every
        `learning_freq` steps, and once the buffer can supply a batch;
        otherwise an empty dict is returned.
        """
        log = {}
        if (self.t > self.learning_starts
                and self.t % self.learning_freq == 0
                and self.replay_buffer.can_sample(self.batch_size)
        ):
            # NOTE(review): argument order follows cs285's DQNCritic.update
            # (obs, acs, next_obs, rewards, terminals) — confirm against
            # dqn_critic.py, which is outside this file.
            log = self.critic.update(
                ob_no, ac_na, next_ob_no, re_n, terminal_n
            )

            # Periodically sync the target network with the online critic.
            if self.num_param_updates % self.target_update_freq == 0:
                self.critic.update_target_network()

            self.num_param_updates += 1

        self.t += 1
        return log