44from torch .autograd import Variable
55
66import numpy as np
7- import sys , os , time , optparse
7+
8+ import os
9+ import sys
10+ import time
11+ import optparse
812
913from tensorboardX import SummaryWriter
1014
11- import config , utils
15+ import utils
16+ import config
17+ from data import Dataset
1218from model import PerformanceRNN
1319from sequence import NoteSeq , EventSeq , ControlSeq
14- from data import Dataset
1520
1621# pylint: disable=E1102
1722# pylint: disable=E1101
1823
#========================================================================
# Settings
#========================================================================
2027
def get_options():
    """Parse command-line options for a training run.

    Defaults for the hyperparameter options come from ``config.train``,
    so the config file remains the single source of truth and the CLI
    only overrides it.

    Returns:
        optparse.Values: parsed options (sess_path, data_path,
        saving_interval, batch_size, learning_rate, window_size,
        stride_size, control_ratio, use_transposition, model_params,
        reset_optimizer).
    """
    # NOTE: optparse is deprecated in favor of argparse; kept here to
    # preserve the existing CLI behavior exactly.
    parser = optparse.OptionParser()

    parser.add_option('-s', '--session',
                      dest='sess_path',
                      type='string',
                      default='train.sess')

    parser.add_option('-d', '--dataset',
                      dest='data_path',
                      type='string',
                      default='dataset/processed/')

    parser.add_option('-i', '--saving-interval',
                      dest='saving_interval',
                      type='float',
                      default=60.)

    parser.add_option('-b', '--batch-size',
                      dest='batch_size',
                      type='int',
                      default=config.train['batch_size'])

    parser.add_option('-l', '--learning-rate',
                      dest='learning_rate',
                      type='float',
                      default=config.train['learning_rate'])

    parser.add_option('-w', '--window-size',
                      dest='window_size',
                      type='int',
                      default=config.train['window_size'])

    parser.add_option('-z', '--stride-size',
                      dest='stride_size',
                      type='int',
                      default=config.train['stride_size'])

    parser.add_option('-c', '--control-ratio',
                      dest='control_ratio',
                      type='float',
                      default=config.train['control_ratio'])

    parser.add_option('-t', '--use-transposition',
                      dest='use_transposition',
                      action='store_true',
                      default=config.train['use_transposition'])

    parser.add_option('-p', '--model-params',
                      dest='model_params',
                      type='string',
                      default='')

    parser.add_option('-r', '--reset-optimizer',
                      dest='reset_optimizer',
                      action='store_true',
                      default=False)

    # parse_args() returns (options, args); positional args are unused.
    return parser.parse_args()[0]
3687
options = get_options()

#------------------------------------------------------------------------

# Paths and checkpointing.
sess_path = options.sess_path
data_path = options.data_path
saving_interval = options.saving_interval

# Training hyperparameters (CLI overrides of config.train).
learning_rate = options.learning_rate
batch_size = options.batch_size
window_size = options.window_size
stride_size = options.stride_size
use_transposition = options.use_transposition
control_ratio = options.control_ratio
reset_optimizer = options.reset_optimizer

# Model dimensions and configuration; '-p key=value,...' overrides
# individual entries of config.model.
event_dim = EventSeq.dim()
control_dim = ControlSeq.dim()
model_config = config.model
model_params = utils.params2dict(options.model_params)
model_config.update(model_params)
device = config.device
54110
# Echo the effective configuration so every training log records it.
print('-' * 50)

print('Session path:', sess_path)
print('Dataset path:', data_path)
print('Saving interval:', saving_interval)
print('-' * 50)

print('Hyperparameters:', utils.dict2params(model_config))
print('Learning rate:', learning_rate)
print('Batch size:', batch_size)
print('Window size:', window_size)
print('Stride size:', stride_size)
print('Control ratio:', control_ratio)
print('Random transposition:', use_transposition)
print('Reset optimizer:', reset_optimizer)
print('Device:', device)
print('-' * 50)
67128

#========================================================================
# Load session and dataset
#========================================================================
70133
def load_session():
    """Build the model and optimizer, restoring state from a checkpoint.

    Tries to load ``sess_path``; on success restores the model weights and,
    unless ``reset_optimizer`` is set, the optimizer state. Any failure
    (missing or incompatible checkpoint) falls back to a fresh session.

    Returns:
        tuple: (model, optimizer) ready for training on ``device``.
    """
    model = PerformanceRNN(**model_config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    try:
        sess = torch.load(sess_path)
        model.load_state_dict(sess['model'])
        if not reset_optimizer:
            optimizer.load_state_dict(sess['optimizer'])
        print('Session is loaded from', sess_path)
    except Exception:
        # Deliberate best-effort: a missing/corrupt checkpoint means a new
        # session. Narrowed from a bare `except:` so Ctrl-C still works.
        print('New session')
    return model, optimizer
83147
def load_dataset():
    """Load the preprocessed dataset from ``data_path``.

    Raises:
        ValueError: if the dataset directory contains no samples.
        (Was an `assert`, which is silently stripped under `python -O`.)
    """
    dataset = Dataset(data_path, verbose=True)
    dataset_size = len(dataset.samples)
    if dataset_size == 0:
        raise ValueError(f'no samples found in dataset path: {data_path}')
    return dataset
90153
91154
# Instantiate (or restore) the session and load the training data.
print('Loading session')
model, optimizer = load_session()
print(model)

print('-' * 50)

print('Loading dataset')
dataset = load_dataset()
print(dataset)

print('-' * 50)
99166
#------------------------------------------------------------------------
101168
def save_model():
    """Checkpoint the current model and optimizer state to ``sess_path``.

    Saves both state_dicts in the format ``load_session`` expects
    ({'model': ..., 'optimizer': ...}).
    """
    global model, optimizer
    print('Saving to', sess_path)
    state = {'model': model.state_dict(),
             'optimizer': optimizer.state_dict()}
    torch.save(state, sess_path)
    print('Done saving')
109176

#========================================================================
# Training
#========================================================================
121181
# TensorBoard logger; the training loop records per-iteration loss here.
writer = SummaryWriter()
@@ -126,33 +186,38 @@ def compute_gradient_norm(parameters, norm_type=2):
126186try :
127187 batch_gen = dataset .batches (batch_size , window_size , stride_size )
128188
129- for iteration , batch in enumerate (batch_gen ):
130- events , controls = batch # [steps, batch] [steps, batch, control_dim]
131-
189+ for iteration , (events , controls ) in enumerate (batch_gen ):
132190 if use_transposition :
133191 offset = np .random .choice (np .arange (- 6 , 6 ))
134192 events , controls = utils .transposition (events , controls , offset )
135193
136194 events = torch .LongTensor (events ).to (device )
137- controls = torch .FloatTensor (controls ).to (device )
195+ assert events .shape [0 ] == window_size
196+
197+ if np .random .random () < control_ratio :
198+ controls = torch .FloatTensor (controls ).to (device )
199+ assert controls .shape [0 ] == window_size
200+ else :
201+ controls = None
202+
138203 init = torch .randn (batch_size , model .init_dim ).to (device )
139204 init .requires_grad_ () # start tracking the graph
140205
141- assert window_size == events .shape [0 ] == controls .shape [0 ]
142206 outputs = model .generate (init , window_size , events [:- 1 ], controls , output_type = 'logit' )
143207 assert outputs .shape [:2 ] == events .shape [:2 ]
144- loss = loss_function (outputs .view (- 1 , event_dim ), events .view (- 1 ))
145208
209+ loss = loss_function (outputs .view (- 1 , event_dim ), events .view (- 1 ))
146210 model .zero_grad ()
147211 loss .backward ()
148- norm = compute_gradient_norm ( model . parameters () )
212+ writer . add_scalar ( 'loss' , loss . item (), iteration )
149213
150- nn .utils .clip_grad_norm_ (model .parameters (), 1.0 )
214+ # norm = utils.compute_gradient_norm(model.parameters())
215+ # nn.utils.clip_grad_norm_(model.parameters(), 1.0)
216+ # writer.add_scalar('norm', norm.item(), iteration)
217+
151218 optimizer .step ()
152219
153- writer .add_scalar ('loss' , loss .item (), iteration )
154- writer .add_scalar ('norm' , norm .item (), iteration )
155- print ('iter {}, loss: {}' .format (iteration , loss .item ()))
220+ print (f'iter { iteration } , loss: { loss .item ()} ' )
156221
157222 if time .time () - last_saving_time > saving_interval :
158223 save_model ()
0 commit comments