@@ -145,20 +145,30 @@ def get_options():
145145#========================================================================
146146
147147def load_session ():
148- model = PerformanceRNN (** model_config ).to (device )
149- optimizer = optim .Adam (model .parameters (), lr = learning_rate )
148+ global sess_path , model_config , device , learning_rate , reset_optimizer
150149 try :
151150 sess = torch .load (sess_path )
152- model .load_state_dict (sess ['model' ])
153- if not reset_optimizer :
154- optimizer .load_state_dict (sess ['optimizer' ])
151+ if 'model_config' in sess and sess ['model_config' ] != model_config :
152+ model_config = sess ['model_config' ]
153+ print ('Use session config instead:' )
154+ print (utils .dict2params (model_config ))
155+ model_state = sess ['model_state' ]
156+ optimizer_state = sess ['optimizer_state' ]
155157 print ('Session is loaded from' , sess_path )
158+ sess_loaded = True
156159 except :
157160 print ('New session' )
158- pass
161+ sess_loaded = False
162+ model = PerformanceRNN (** model_config ).to (device )
163+ optimizer = optim .Adam (model .parameters (), lr = learning_rate )
164+ if sess_loaded :
165+ model .load_state_dict (model_state )
166+ if not reset_optimizer :
167+ optimizer .load_state_dict (optimizer_state )
159168 return model , optimizer
160169
161170def load_dataset ():
171+ global data_path
162172 dataset = Dataset (data_path , verbose = True )
163173 dataset_size = len (dataset .samples )
164174 assert dataset_size > 0
@@ -180,11 +190,11 @@ def load_dataset():
180190#------------------------------------------------------------------------
181191
182192def save_model ():
183- global model , optimizer
193+ global model , optimizer , model_config , sess_path
184194 print ('Saving to' , sess_path )
185- state = { 'model ' : model . state_dict () ,
186- 'optimizer ' : optimizer .state_dict ()}
187- torch . save ( state , sess_path )
195+ torch . save ({ 'model_config ' : model_config ,
196+ 'model_state ' : model .state_dict (),
197+ 'optimizer_state' : optimizer . state_dict ()} , sess_path )
188198 print ('Done saving' )
189199
190200
@@ -233,7 +243,7 @@ def save_model():
233243 if enable_logging :
234244 writer .add_scalar ('loss' , loss .item (), iteration )
235245 # writer.add_scalar('norm', norm.item(), iteration)
236-
246+
237247 print (f'iter { iteration } , loss: { loss .item ()} ' )
238248
239249 if time .time () - last_saving_time > saving_interval :
0 commit comments