Skip to content
This repository was archived by the owner on Jan 26, 2024. It is now read-only.

Commit 0a6d271

Browse files
committed
Fix session loading problems
1 parent f5c69ed commit 0a6d271

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

generate.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,12 @@ def getopt():
151151
# Generating
152152
#========================================================================
153153

154-
model = PerformanceRNN(**model_config).to(device)
155-
model.load_state_dict(torch.load(sess_path)['model'])
154+
state = torch.load(sess_path)
155+
model = PerformanceRNN(**state['model_config']).to(device)
156+
model.load_state_dict(state['model_state'])
156157
model.eval()
158+
print(model)
159+
print('-' * 50)
157160

158161
# Don't build the graph
159162
for parameter in model.parameters():

train.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -145,20 +145,30 @@ def get_options():
145145
#========================================================================
146146

147147
def load_session():
148-
model = PerformanceRNN(**model_config).to(device)
149-
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
148+
global sess_path, model_config, device, learning_rate, reset_optimizer
150149
try:
151150
sess = torch.load(sess_path)
152-
model.load_state_dict(sess['model'])
153-
if not reset_optimizer:
154-
optimizer.load_state_dict(sess['optimizer'])
151+
if 'model_config' in sess and sess['model_config'] != model_config:
152+
model_config = sess['model_config']
153+
print('Use session config instead:')
154+
print(utils.dict2params(model_config))
155+
model_state = sess['model_state']
156+
optimizer_state = sess['optimizer_state']
155157
print('Session is loaded from', sess_path)
158+
sess_loaded = True
156159
except:
157160
print('New session')
158-
pass
161+
sess_loaded = False
162+
model = PerformanceRNN(**model_config).to(device)
163+
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
164+
if sess_loaded:
165+
model.load_state_dict(model_state)
166+
if not reset_optimizer:
167+
optimizer.load_state_dict(optimizer_state)
159168
return model, optimizer
160169

161170
def load_dataset():
171+
global data_path
162172
dataset = Dataset(data_path, verbose=True)
163173
dataset_size = len(dataset.samples)
164174
assert dataset_size > 0
@@ -180,11 +190,11 @@ def load_dataset():
180190
#------------------------------------------------------------------------
181191

182192
def save_model():
183-
global model, optimizer
193+
global model, optimizer, model_config, sess_path
184194
print('Saving to', sess_path)
185-
state = {'model': model.state_dict(),
186-
'optimizer': optimizer.state_dict()}
187-
torch.save(state, sess_path)
195+
torch.save({'model_config': model_config,
196+
'model_state': model.state_dict(),
197+
'optimizer_state': optimizer.state_dict()}, sess_path)
188198
print('Done saving')
189199

190200

@@ -233,7 +243,7 @@ def save_model():
233243
if enable_logging:
234244
writer.add_scalar('loss', loss.item(), iteration)
235245
# writer.add_scalar('norm', norm.item(), iteration)
236-
246+
237247
print(f'iter {iteration}, loss: {loss.item()}')
238248

239249
if time.time() - last_saving_time > saving_interval:

0 commit comments

Comments (0)