Skip to content

Commit e866d57

Browse files
committed
[ADD] validation step to training loop. The dynamics of validation acc/loss from epoch to epoch are different from Keras. Still need some visualization and examination of results to ensure there isn't a bug. Still need a Testing step for parity with Keras. Also need to save model checkpoints and best overall models
1 parent 2f9eb91 commit e866d57

File tree

1 file changed

+62
-14
lines changed

1 file changed

+62
-14
lines changed

lstm_genre_classifier_pytorch.py

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,19 @@
4242
genre_features.load_preprocess_data()
4343

4444
train_X = torch.from_numpy(genre_features.train_X).type(torch.Tensor)
45+
dev_X = torch.from_numpy(genre_features.dev_X).type(torch.Tensor)
4546
test_X = torch.from_numpy(genre_features.test_X).type(torch.Tensor)
4647

4748
# Targets is a long tensor of size (N,) which tells the true class of the sample.
4849
train_Y = torch.from_numpy(genre_features.train_Y).type(torch.LongTensor)
50+
dev_Y = torch.from_numpy(genre_features.dev_Y).type(torch.LongTensor)
4951
test_Y = torch.from_numpy(genre_features.test_Y).type(torch.LongTensor)
5052

5153
# Convert {training, test} torch.Tensors
5254
print("Training X shape: " + str(genre_features.train_X.shape))
5355
print("Training Y shape: " + str(genre_features.train_Y.shape))
56+
print("Validation X shape: " + str(genre_features.dev_X.shape))
57+
print("Validation Y shape: " + str(genre_features.dev_Y.shape))
5458
print("Test X shape: " + str(genre_features.test_X.shape))
5559
print("Test Y shape: " + str(genre_features.test_Y.shape))
5660

@@ -86,7 +90,9 @@ def forward(self, input):
8690

8791
def get_accuracy(self, logits, target):
    """Return minibatch accuracy as a percentage.

    Compares the argmax class of each row of `logits` against the
    true labels in `target` and scales the hit count by the configured
    minibatch size (`self.batch_size`).
    """
    # Index [1] of torch.max is the argmax (predicted class) per row.
    predicted = torch.max(logits, 1)[1].view(target.size())
    num_correct = (predicted.data == target.data).sum()
    # .item() unpacks the 0-dim tensor into a plain Python float.
    return (100.0 * num_correct / self.batch_size).item()
9298

@@ -97,24 +103,25 @@ def get_accuracy(self, logits, target):
97103
# Define model
98104
print("Build LSTM RNN model ...")
99105
model = LSTM(
100-
input_dim=33,
101-
hidden_dim=128,
102-
batch_size=batch_size,
103-
output_dim=8,
104-
num_layers=2,
106+
input_dim=33, hidden_dim=128, batch_size=batch_size, output_dim=8, num_layers=2
105107
)
106108
loss_function = nn.NLLLoss()
107109
optimizer = optim.Adam(model.parameters(), lr=0.001)
108110

109-
110-
print("Training ...")
111+
train_on_gpu = torch.cuda.is_available()
112+
if train_on_gpu:
113+
print("\nTraining on GPU")
114+
else:
115+
print("\nNo GPU, training on CPU")
111116

112117
# all training data (epoch) / batch_size == num_batches (12)
113-
num_batches = int(train_X.shape[0] / batch_size)
118+
num_batches = int(train_X.shape[0] / batch_size)
119+
num_dev_batches = int(dev_X.shape[0] / batch_size)
114120

115121
for epoch in range(num_epochs):
116-
train_running_loss = 0.0
117-
train_acc = 0.0
122+
123+
print("Training ...")
124+
train_running_loss, train_acc = 0.0, 0.0
118125

119126
# Init hidden state - if you don't want a stateful LSTM (between epochs)
120127
model.hidden = model.init_hidden()
@@ -129,8 +136,8 @@ def get_accuracy(self, logits, target):
129136
# Slice out local minibatches & labels => Note that we *permute* the local minibatch to
130137
# match the PyTorch expected input tensor format of (sequence_length, batch size, input_dim)
131138
X_local_minibatch, y_local_minibatch = (
132-
train_X[i * batch_size: (i + 1) * batch_size,],
133-
train_Y[i * batch_size: (i + 1) * batch_size,]
139+
train_X[i * batch_size : (i + 1) * batch_size,],
140+
train_Y[i * batch_size : (i + 1) * batch_size,],
134141
)
135142

136143
# Reshape input & targets to "match" what the loss_function wants
@@ -144,11 +151,52 @@ def get_accuracy(self, logits, target):
144151
loss.backward() # reeeeewind (backward pass)
145152
optimizer.step() # parameter update
146153

147-
train_running_loss += loss.detach().item()
154+
train_running_loss += loss.detach().item() # unpacks the tensor into a scalar value
148155
train_acc += model.get_accuracy(y_pred, y_local_minibatch)
149156

150157
print(
151158
"Epoch: %d | NLLoss: %.4f | Train Accuracy: %.2f"
152159
% (epoch, train_running_loss / num_batches, train_acc / num_batches)
153160
)
154161

162+
print("Validation ...") # should this be done every N epochs
163+
print_every = 1
164+
165+
if epoch % print_every == 0:
166+
167+
# Get validation loss
168+
with torch.no_grad():
169+
val_running_loss, val_acc = 0.0, 0.0
170+
model.eval()
171+
172+
model.hidden = model.init_hidden()
173+
for i in range(num_dev_batches):
174+
X_local_validation_minibatch, y_local_validation_minibatch = (
175+
dev_X[i * batch_size : (i + 1) * batch_size,],
176+
dev_Y[i * batch_size : (i + 1) * batch_size,],
177+
)
178+
X_local_minibatch = X_local_validation_minibatch.permute(1, 0, 2)
179+
y_local_minibatch = torch.max(y_local_validation_minibatch, 1)[1]
180+
181+
y_pred = model(X_local_minibatch)
182+
val_loss = loss_function(y_pred, y_local_minibatch)
183+
184+
val_running_loss += (
185+
val_loss.detach().item()
186+
) # unpacks the tensor into a scalar value
187+
val_acc += model.get_accuracy(y_pred, y_local_minibatch)
188+
189+
model.train() # reset to train mode after iterationg through validation data
190+
print(
191+
"Epoch: %d | NLLoss: %.4f | Train Accuracy: %.2f | Val Loss %.4f | Val Accuracy: %.2f"
192+
% (
193+
epoch,
194+
train_running_loss / num_batches,
195+
train_acc / num_batches,
196+
val_running_loss / num_dev_batches,
197+
val_acc / num_dev_batches,
198+
)
199+
)
200+
201+
202+
print("Testing ...")

0 commit comments

Comments
 (0)