Skip to content
This repository was archived by the owner on Jan 26, 2024. It is now read-only.

Commit d78e4af

Browse files
committed
Improve training script
1 parent 9e03111 commit d78e4af

File tree

2 files changed

+129
-63
lines changed

2 files changed

+129
-63
lines changed

config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@
2020
'window_size': 200,
2121
'stride_size': 10,
2222
'use_transposition': False,
23+
'control_ratio': 1.0
2324
}

train.py

Lines changed: 128 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4,119 +4,179 @@
44
from torch.autograd import Variable
55

66
import numpy as np
7-
import sys, os, time, optparse
7+
8+
import os
9+
import sys
10+
import time
11+
import optparse
812

913
from tensorboardX import SummaryWriter
1014

11-
import config, utils
15+
import utils
16+
import config
17+
from data import Dataset
1218
from model import PerformanceRNN
1319
from sequence import NoteSeq, EventSeq, ControlSeq
14-
from data import Dataset
1520

1621
# pylint: disable=E1102
1722
# pylint: disable=E1101
1823

24+
#========================================================================
25+
# Settings
1926
#========================================================================
2027

2128
def get_options():
2229
parser = optparse.OptionParser()
23-
parser.add_option('-s',
24-
dest='save_path',
25-
type='string',
26-
default='train.sess')
27-
parser.add_option('-d',
28-
dest='data_path',
29-
type='string',
30-
default='dataset/processed/')
31-
parser.add_option('-i',
32-
dest='saving_interval',
33-
type='float',
34-
default=60.)
30+
31+
parser.add_option('-s', '--session',
32+
dest='sess_path',
33+
type='string',
34+
default='train.sess')
35+
36+
parser.add_option('-d', '--dataset',
37+
dest='data_path',
38+
type='string',
39+
default='dataset/processed/')
40+
41+
parser.add_option('-i', '--saving-interval',
42+
dest='saving_interval',
43+
type='float',
44+
default=60.)
45+
46+
parser.add_option('-b', '--batch-size',
47+
dest='batch_size',
48+
type='int',
49+
default=config.train['batch_size'])
50+
51+
parser.add_option('-l', '--learning-rate',
52+
dest='learning_rate',
53+
type='float',
54+
default=config.train['learning_rate'])
55+
56+
parser.add_option('-w', '--window-size',
57+
dest='window_size',
58+
type='int',
59+
default=config.train['window_size'])
60+
61+
parser.add_option('-z', '--stride-size',
62+
dest='stride_size',
63+
type='int',
64+
default=config.train['stride_size'])
65+
66+
parser.add_option('-c', '--control-ratio',
67+
dest='control_ratio',
68+
type='float',
69+
default=config.train['control_ratio'])
70+
71+
parser.add_option('-t', '--use-transposition',
72+
dest='use_transposition',
73+
action='store_true',
74+
default=config.train['use_transposition'])
75+
76+
parser.add_option('-p', '--model-params',
77+
dest='model_params',
78+
type='string',
79+
default='')
80+
81+
parser.add_option('-r', '--reset-optimizer',
82+
dest='reset_optimizer',
83+
action='store_true',
84+
default=False)
85+
3586
return parser.parse_args()[0]
3687

3788
options = get_options()
3889

39-
#========================================================================
90+
#------------------------------------------------------------------------
4091

41-
save_path = options.save_path
92+
sess_path = options.sess_path
4293
data_path = options.data_path
4394
saving_interval = options.saving_interval
4495

96+
learning_rate = options.learning_rate
97+
batch_size = options.batch_size
98+
window_size = options.window_size
99+
stride_size = options.stride_size
100+
use_transposition = options.use_transposition
101+
control_ratio = options.control_ratio
102+
reset_optimizer = options.reset_optimizer
103+
45104
event_dim = EventSeq.dim()
46105
control_dim = ControlSeq.dim()
47106
model_config = config.model
48-
learning_rate = config.train['learning_rate']
49-
batch_size = config.train['batch_size']
50-
window_size = config.train['window_size']
51-
stride_size = config.train['stride_size']
52-
use_transposition = config.train['use_transposition']
107+
model_params = utils.params2dict(options.model_params)
108+
model_config.update(model_params)
53109
device = config.device
54110

55-
print('=' * 50)
56-
print('Saving path:', options.save_path)
57-
print('Dataset path:', options.data_path)
58-
print('Saving interval:', options.saving_interval)
59-
print('Event dimension:', event_dim)
60-
print('Hyperparameters:', model_config)
111+
print('-' * 50)
112+
113+
print('Session path:', sess_path)
114+
print('Dataset path:', data_path)
115+
print('Saving interval:', saving_interval)
116+
print('-' * 50)
117+
118+
print('Hyperparameters:', utils.dict2params(model_config))
61119
print('Learning rate:', learning_rate)
62120
print('Batch size:', batch_size)
63121
print('Window size:', window_size)
64122
print('Stride size:', stride_size)
123+
print('Control ratio:', control_ratio)
65124
print('Random transposition:', use_transposition)
66-
print('=' * 50)
125+
print('Reset optimizer:', reset_optimizer)
126+
print('Device:', device)
127+
print('-' * 50)
67128

68-
#========================================================================
69129

130+
#========================================================================
131+
# Load session and dataset
132+
#========================================================================
70133

71-
def load_model():
134+
def load_session():
72135
model = PerformanceRNN(**model_config).to(device)
73136
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
74137
try:
75-
sess = torch.load(save_path)
138+
sess = torch.load(sess_path)
76139
model.load_state_dict(sess['model'])
77-
optimizer.load_state_dict(sess['optimizer'])
78-
print('Session loaded')
140+
if not reset_optimizer:
141+
optimizer.load_state_dict(sess['optimizer'])
142+
print('Session is loaded from', sess_path)
79143
except:
144+
print('New session')
80145
pass
81-
print(model)
82146
return model, optimizer
83147

84148
def load_dataset():
85-
dataset = Dataset(data_path)
149+
dataset = Dataset(data_path, verbose=True)
86150
dataset_size = len(dataset.samples)
87151
assert dataset_size > 0
88-
print('Dataset size:', dataset_size)
89152
return dataset
90153

91154

92-
print('Loading model')
93-
model, optimizer = load_model()
155+
print('Loading session')
156+
model, optimizer = load_session()
157+
print(model)
158+
159+
print('-' * 50)
94160

95161
print('Loading dataset')
96162
dataset = load_dataset()
163+
print(dataset)
97164

98-
print('=' * 50)
165+
print('-' * 50)
99166

100-
#========================================================================
167+
#------------------------------------------------------------------------
101168

102169
def save_model():
103170
global model, optimizer
104-
print('Saving to', save_path)
171+
print('Saving to', sess_path)
105172
state = {'model': model.state_dict(),
106173
'optimizer': optimizer.state_dict()}
107-
torch.save(state, save_path)
174+
torch.save(state, sess_path)
108175
print('Done saving')
109176

110-
#========================================================================
111-
112-
def compute_gradient_norm(parameters, norm_type=2):
113-
total_norm = 0
114-
for p in parameters:
115-
param_norm = p.grad.data.norm(norm_type)
116-
total_norm += param_norm ** norm_type
117-
total_norm = total_norm ** (1. / norm_type)
118-
return total_norm
119177

178+
#========================================================================
179+
# Training
120180
#========================================================================
121181

122182
writer = SummaryWriter()
@@ -126,33 +186,38 @@ def compute_gradient_norm(parameters, norm_type=2):
126186
try:
127187
batch_gen = dataset.batches(batch_size, window_size, stride_size)
128188

129-
for iteration, batch in enumerate(batch_gen):
130-
events, controls = batch # [steps, batch] [steps, batch, control_dim]
131-
189+
for iteration, (events, controls) in enumerate(batch_gen):
132190
if use_transposition:
133191
offset = np.random.choice(np.arange(-6, 6))
134192
events, controls = utils.transposition(events, controls, offset)
135193

136194
events = torch.LongTensor(events).to(device)
137-
controls = torch.FloatTensor(controls).to(device)
195+
assert events.shape[0] == window_size
196+
197+
if np.random.random() < control_ratio:
198+
controls = torch.FloatTensor(controls).to(device)
199+
assert controls.shape[0] == window_size
200+
else:
201+
controls = None
202+
138203
init = torch.randn(batch_size, model.init_dim).to(device)
139204
init.requires_grad_() # start tracking the graph
140205

141-
assert window_size == events.shape[0] == controls.shape[0]
142206
outputs = model.generate(init, window_size, events[:-1], controls, output_type='logit')
143207
assert outputs.shape[:2] == events.shape[:2]
144-
loss = loss_function(outputs.view(-1, event_dim), events.view(-1))
145208

209+
loss = loss_function(outputs.view(-1, event_dim), events.view(-1))
146210
model.zero_grad()
147211
loss.backward()
148-
norm = compute_gradient_norm(model.parameters())
212+
writer.add_scalar('loss', loss.item(), iteration)
149213

150-
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
214+
# norm = utils.compute_gradient_norm(model.parameters())
215+
# nn.utils.clip_grad_norm_(model.parameters(), 1.0)
216+
# writer.add_scalar('norm', norm.item(), iteration)
217+
151218
optimizer.step()
152219

153-
writer.add_scalar('loss', loss.item(), iteration)
154-
writer.add_scalar('norm', norm.item(), iteration)
155-
print('iter {}, loss: {}'.format(iteration, loss.item()))
220+
print(f'iter {iteration}, loss: {loss.item()}')
156221

157222
if time.time() - last_saving_time > saving_interval:
158223
save_model()

0 commit comments

Comments
 (0)