run_hw3_actor_critic.py
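"""Entry point for the CS285 HW3 actor-critic experiments.

Parses command-line hyperparameters, packs them into the configuration
expected by ACAgent and RL_Trainer, and runs the training loop.
"""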
import os
import time
from cs285.agents.ac_agent import ACAgent
from cs285.infrastructure.rl_trainer import RL_Trainer

class AC_Trainer(object):

    def __init__(self, params):

        #####################
        ## SET AGENT PARAMS
        #####################

        computation_graph_args = {
            'n_layers': params['n_layers'],
            'size': params['size'],
            'learning_rate': params['learning_rate'],
            'num_target_updates': params['num_target_updates'],
            'num_grad_steps_per_target_update': params['num_grad_steps_per_target_update'],
        }

        estimate_advantage_args = {
            'gamma': params['discount'],
            'standardize_advantages': not params['dont_standardize_advantages'],
        }

        train_args = {
            'num_agent_train_steps_per_iter': params['num_agent_train_steps_per_iter'],
            'num_critic_updates_per_agent_update': params['num_critic_updates_per_agent_update'],
            'num_actor_updates_per_agent_update': params['num_actor_updates_per_agent_update'],
        }

        agent_params = {**computation_graph_args, **estimate_advantage_args, **train_args}

        self.params = params
        self.params['agent_class'] = ACAgent
        self.params['agent_params'] = agent_params
        self.params['batch_size_initial'] = self.params['batch_size']

        ################
        ## RL TRAINER
        ################

        self.rl_trainer = RL_Trainer(self.params)

    def run_training_loop(self):
        self.rl_trainer.run_training_loop(
            self.params['n_iter'],
            collect_policy=self.rl_trainer.agent.actor,
            eval_policy=self.rl_trainer.agent.actor,
        )
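
# Note: the same actor network serves as both the data-collection policy and
# the evaluation policy above, as expected for an on-policy actor-critic method.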

def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='CartPole-v0')
    parser.add_argument('--ep_len', type=int, default=200)
    parser.add_argument('--exp_name', type=str, default='todo')
    parser.add_argument('--n_iter', '-n', type=int, default=200)

    parser.add_argument('--num_agent_train_steps_per_iter', type=int, default=1)
    parser.add_argument('--num_critic_updates_per_agent_update', type=int, default=1)
    parser.add_argument('--num_actor_updates_per_agent_update', type=int, default=1)

    parser.add_argument('--batch_size', '-b', type=int, default=1000)  # steps collected per train iteration
    parser.add_argument('--eval_batch_size', '-eb', type=int, default=400)  # steps collected per eval iteration
    parser.add_argument('--train_batch_size', '-tb', type=int, default=1000)  # steps used per gradient step

    parser.add_argument('--discount', type=float, default=1.0)
    parser.add_argument('--learning_rate', '-lr', type=float, default=5e-3)
    parser.add_argument('--dont_standardize_advantages', '-dsa', action='store_true')
    parser.add_argument('--num_target_updates', '-ntu', type=int, default=10)
    parser.add_argument('--num_grad_steps_per_target_update', '-ngsptu', type=int, default=10)
    parser.add_argument('--n_layers', '-l', type=int, default=2)
    parser.add_argument('--size', '-s', type=int, default=64)

    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--no_gpu', '-ngpu', action='store_true')
    parser.add_argument('--which_gpu', '-gpu_id', type=int, default=0)
    parser.add_argument('--video_log_freq', type=int, default=-1)
    parser.add_argument('--scalar_log_freq', type=int, default=10)
    parser.add_argument('--save_params', action='store_true')

    args = parser.parse_args()

    # convert to dictionary
    params = vars(args)

    # as in the policy gradient assignment, we force batch_size = train_batch_size,
    # so any --train_batch_size value passed above is overridden here
    params['train_batch_size'] = params['batch_size']
    ##################################
    ### CREATE DIRECTORY FOR LOGGING
    ##################################

    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data')

    if not os.path.exists(data_path):
        os.makedirs(data_path)

    logdir = 'hw3_' + args.exp_name + '_' + args.env_name + '_' + time.strftime("%d-%m-%Y_%H-%M-%S")
    logdir = os.path.join(data_path, logdir)
    params['logdir'] = logdir
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    print("\n\n\nLOGGING TO: ", logdir, "\n\n\n")

    ###################
    ### RUN TRAINING
    ###################

    trainer = AC_Trainer(params)
    trainer.run_training_loop()


if __name__ == "__main__":
    main()
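
# Example invocation (illustrative; every flag used here is defined in main()
# above, and the environment/hyperparameter values shown are just one choice):
#
#   python run_hw3_actor_critic.py --env_name CartPole-v0 -n 100 -b 1000 \
#       -ntu 10 -ngsptu 10 --exp_name cartpole_ac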