Skip to content

Commit c32344e

Browse files
committed
[ADD] a few cleanups to logging to stdout, and rename hidden -> hidden_state to avoid inter-scope shadowing
1 parent 8fe3443 commit c32344e

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

lstm_genre_classifier_pytorch.py

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616

1717
import os
1818
import torch
19-
import numpy as np
2019
import torch.nn as nn
2120
import torch.nn.functional as F
2221
import torch.optim as optim
@@ -127,10 +126,10 @@ def get_accuracy(self, logits, target):
127126
train_running_loss, train_acc = 0.0, 0.0
128127

129128
# Init hidden state - if you don't want a stateful LSTM (between epochs)
130-
hidden = None
129+
hidden_state = None
131130
for i in range(num_batches):
132131

133-
# zero out gradient, so they don't accumulate btw epochs
132+
# zero out gradient, so they don't accumulate btw batches
134133
model.zero_grad()
135134

136135
# train_X shape: (total # of training examples, sequence_length, input_dim)
@@ -149,18 +148,18 @@ def get_accuracy(self, logits, target):
149148
# NLLLoss does not expect a one-hot encoded vector as the target, but class indices
150149
y_local_minibatch = torch.max(y_local_minibatch, 1)[1]
151150

152-
y_pred, hidden = model(X_local_minibatch, hidden) # forward pass
151+
y_pred, hidden_state = model(X_local_minibatch, hidden_state) # forward pass
153152

154153
# Stateful = False for training. Do we go Stateful = True during inference/prediction time?
155154
if not stateful:
156-
hidden = None
155+
hidden_state = None
157156
else:
158-
h_0, c_0 = hidden
157+
h_0, c_0 = hidden_state
159158
h_0.detach_(), c_0.detach_()
160-
hidden = (h_0, c_0)
159+
hidden_state = (h_0, c_0)
161160

162161
loss = loss_function(y_pred, y_local_minibatch) # compute loss
163-
loss.backward() # reeeeewind (backward pass)
162+
loss.backward() # backward pass
164163
optimizer.step() # parameter update
165164

166165
train_running_loss += loss.detach().item() # unpacks the tensor into a scalar value
@@ -171,15 +170,15 @@ def get_accuracy(self, logits, target):
171170
% (epoch, train_running_loss / num_batches, train_acc / num_batches)
172171
)
173172

174-
print("Validation ...") # should this be done every N epochs
175173
if epoch % 10 == 0:
174+
print("Validation ...") # should this be done every N=10 epochs
176175
val_running_loss, val_acc = 0.0, 0.0
177176

178177
# Compute validation loss, accuracy. Use torch.no_grad() & model.eval()
179178
with torch.no_grad():
180179
model.eval()
181180

182-
hidden = None
181+
hidden_state = None
183182
for i in range(num_dev_batches):
184183
X_local_validation_minibatch, y_local_validation_minibatch = (
185184
dev_X[i * batch_size: (i + 1) * batch_size, ],
@@ -188,9 +187,9 @@ def get_accuracy(self, logits, target):
188187
X_local_minibatch = X_local_validation_minibatch.permute(1, 0, 2)
189188
y_local_minibatch = torch.max(y_local_validation_minibatch, 1)[1]
190189

191-
y_pred, hidden = model(X_local_minibatch, hidden)
190+
y_pred, hidden_state = model(X_local_minibatch, hidden_state)
192191
if not stateful:
193-
hidden = None
192+
hidden_state = None
194193

195194
val_loss = loss_function(y_pred, y_local_minibatch)
196195

0 commit comments

Comments (0)