import abc
import itertools

import numpy as np
import torch
from torch import distributions, nn, optim
from torch.nn import functional as F

from cs285.infrastructure import pytorch_util as ptu
from cs285.policies.base_policy import BasePolicy

class MLPPolicy(BasePolicy, nn.Module, metaclass=abc.ABCMeta):

    def __init__(self,
                 ac_dim,
                 ob_dim,
                 n_layers,
                 size,
                 discrete=False,
                 learning_rate=1e-4,
                 training=True,
                 nn_baseline=False,
                 **kwargs
                 ):
        super().__init__(**kwargs)

        # init vars
        self.ac_dim = ac_dim
        self.ob_dim = ob_dim
        self.n_layers = n_layers
        self.discrete = discrete
        self.size = size
        self.learning_rate = learning_rate
        self.training = training
        self.nn_baseline = nn_baseline

        if self.discrete:
            # categorical policy: an MLP maps observations to action logits
            self.logits_na = ptu.build_mlp(input_size=self.ob_dim,
                                           output_size=self.ac_dim,
                                           n_layers=self.n_layers,
                                           size=self.size)
            self.logits_na.to(ptu.device)
            self.mean_net = None
            self.logstd = None
            self.optimizer = optim.Adam(self.logits_na.parameters(),
                                        self.learning_rate)
        else:
            # continuous actions: a diagonal Gaussian whose mean comes from an
            # MLP, with a learned state-independent log standard deviation
            self.logits_na = None
            self.mean_net = ptu.build_mlp(input_size=self.ob_dim,
                                          output_size=self.ac_dim,
                                          n_layers=self.n_layers,
                                          size=self.size)
            self.logstd = nn.Parameter(
                torch.zeros(self.ac_dim, dtype=torch.float32, device=ptu.device)
            )
            self.mean_net.to(ptu.device)
            self.logstd.to(ptu.device)
            self.optimizer = optim.Adam(
                itertools.chain([self.logstd], self.mean_net.parameters()),
                self.learning_rate
            )

        if nn_baseline:
            # optional value-function baseline with its own optimizer
            self.baseline = ptu.build_mlp(
                input_size=self.ob_dim,
                output_size=1,
                n_layers=self.n_layers,
                size=self.size,
            )
            self.baseline.to(ptu.device)
            self.baseline_optimizer = optim.Adam(
                self.baseline.parameters(),
                self.learning_rate,
            )
        else:
            self.baseline = None
    ##################################

    def save(self, filepath):
        torch.save(self.state_dict(), filepath)

    ##################################
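    # Note (not in the original file): a policy saved with `save` above can be
    # restored later via `policy.load_state_dict(torch.load(filepath))`.
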
    # query the policy with observation(s) to get selected action(s)
    def get_action(self, obs: np.ndarray) -> np.ndarray:
        # left as a TODO in the released code; a standard sketch: batch a
        # single observation, then sample from the forward distribution
        observation = obs if obs.ndim > 1 else obs[None]
        action_distribution = self(ptu.from_numpy(observation))
        return ptu.to_numpy(action_distribution.sample())

    # update/train this policy
    def update(self, observations, actions, **kwargs):
        raise NotImplementedError

    # This function defines the forward pass of the network.
    # You can return anything you want, but you should be able to differentiate
    # through it. For example, you can return a torch.FloatTensor. You can also
    # return more flexible objects, such as a
    # `torch.distributions.Distribution` object. It's up to you!
    def forward(self, observation: torch.FloatTensor):
        # left as a TODO in the released code; a standard sketch returning a
        # torch.distributions.Distribution for either action-space type
        if self.discrete:
            action_distribution = distributions.Categorical(
                logits=self.logits_na(observation))
        else:
            action_distribution = distributions.Normal(
                self.mean_net(observation), torch.exp(self.logstd))
        return action_distribution

#####################################################
#####################################################


class MLPPolicyAC(MLPPolicy):
    def update(self, observations, actions, adv_n=None):
        # left as a TODO in the released code; a sketch of a standard
        # advantage-weighted policy-gradient step
        observations = ptu.from_numpy(observations)
        actions = ptu.from_numpy(actions)
        adv_n = ptu.from_numpy(adv_n)
        log_prob = self(observations).log_prob(actions)
        if not self.discrete:
            log_prob = log_prob.sum(dim=-1)  # sum over action dimensions
        loss = -(log_prob * adv_n).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
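
# --- Illustrative usage (not part of the original file) ---
# A minimal smoke test, assuming the cs285 package is importable; it
# exercises the continuous-action path with random data.
if __name__ == "__main__":
    policy = MLPPolicyAC(ac_dim=2, ob_dim=4, n_layers=2, size=64)
    obs = np.random.randn(5, 4).astype(np.float32)
    actions = policy.get_action(obs)  # sampled actions, shape (5, 2)
    advantages = np.random.randn(5).astype(np.float32)
    loss = policy.update(obs, actions, adv_n=advantages)
    print(f"actions shape: {actions.shape}, loss: {loss:.4f}")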