# grammar_vae.py
import h5py
import numpy as np
import torch.utils.data
import torch.optim as optim
from torch.autograd import Variable

from model import GrammarVariationalAutoEncoder, VAELoss
from visdom_helper.visdom_helper import Dashboard
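
# NOTE: Dashboard (from visdom_helper) plots to a visdom server; this assumes
# a local visdom instance, typically started with `python -m visdom.server`.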


class Session:
    def __init__(self, model, train_step_init=0, lr=1e-3, is_cuda=False):
        self.train_step = train_step_init
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.loss_fn = VAELoss()
        self.dashboard = Dashboard('Grammar-Variational-Autoencoder-experiment')

    def train(self, loader, epoch_number):
        # train() is a built-in nn.Module method; it sets the training flag.
        self.model.train()
        _losses = []
        for batch_idx, data in enumerate(loader):
            # data has to be cast to a FloatTensor; DoubleTensor errors with Conv1d.
            data = Variable(data)
            # do not use CUDA at the moment
            self.optimizer.zero_grad()
            recon_batch, mu, log_var = self.model(data)
            loss = self.loss_fn(data, mu, log_var, recon_batch)
            loss.backward()
            self.optimizer.step()
            self.train_step += 1

            # .data.numpy() detaches the scalar loss from the graph for logging.
            loss_value = loss.data.numpy()
            _losses.append(loss_value)
            batch_size = len(data)
            self.dashboard.append('training_loss', 'line',
                                  X=np.array([self.train_step]),
                                  Y=loss_value / batch_size)

            if batch_idx == 0:
                print('batch size', batch_size)
            if batch_idx % 40 == 0:
                print('training loss: {:.4f}'.format(loss_value[0] / batch_size))
        return _losses

    def test(self, loader):
        # eval() is an nn.Module method; it sets the training flag to False.
        self.model.eval()
        test_loss = 0
        for batch_idx, data in enumerate(loader):
            # volatile=True skips gradient book-keeping during inference
            data = Variable(data, volatile=True)
            # do not use CUDA at the moment
            recon_batch, mu, log_var = self.model(data)
            test_loss += self.loss_fn(data, mu, log_var, recon_batch).data[0]
        test_loss /= len(loader.dataset)
        print('testset length', len(loader.dataset))
        print('====> Test set loss: {:.4f}'.format(test_loss))


EPOCHS = 20
BATCH_SIZE = 200

def kfold_loader(k, s, e=None):
    """Split the dataset into k strided folds and concatenate folds s..e-1."""
    if e is None:
        e = k
    with h5py.File('data/eq2_grammar_dataset.h5', 'r') as h5f:
        result = np.concatenate([h5f['data'][i::k] for i in range(s, e)])
        return torch.FloatTensor(result)
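
# Fold semantics (indices are illustrative): with k=10, kfold_loader(10, 1)
# stacks folds 1..9 (rows 1, 11, 21, ..., then 2, 12, 22, ...), while
# kfold_loader(10, 0, 1) keeps only fold 0 (rows 0, 10, 20, ...), giving a
# roughly 90/10 train/test split of the same file.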

train_loader = torch.utils.data.DataLoader(
    kfold_loader(10, 1),
    batch_size=BATCH_SIZE, shuffle=False)
# todo: need to have separate training and validation sets
test_loader = torch.utils.data.DataLoader(
    kfold_loader(10, 0, 1),
    batch_size=BATCH_SIZE, shuffle=False)

losses = []
vae = GrammarVariationalAutoEncoder()
sess = Session(vae, lr=2e-3)
for epoch in range(1, EPOCHS + 1):
    losses += sess.train(train_loader, epoch)
    print('epoch {} complete'.format(epoch))
    sess.test(test_loader)
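
# A minimal checkpointing sketch; the filename is illustrative, not part of
# the original script:
torch.save(vae.state_dict(), 'grammar_vae.pt')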