1616
1717import os
1818import torch
19- import numpy as np
2019import torch .nn as nn
2120import torch .nn .functional as F
2221import torch .optim as optim
@@ -127,10 +126,10 @@ def get_accuracy(self, logits, target):
127126 train_running_loss , train_acc = 0.0 , 0.0
128127
129128 # Init hidden state - if you don't want a stateful LSTM (between epochs)
130- hidden = None
129+ hidden_state = None
131130 for i in range (num_batches ):
132131
133- # zero out gradient, so they don't accumulate btw epochs
132+ # zero out gradient, so they don't accumulate btw batches
134133 model .zero_grad ()
135134
136135 # train_X shape: (total # of training examples, sequence_length, input_dim)
@@ -149,18 +148,18 @@ def get_accuracy(self, logits, target):
149148 # NLLLoss does not expect a one-hot encoded vector as the target, but class indices
150149 y_local_minibatch = torch .max (y_local_minibatch , 1 )[1 ]
151150
152- y_pred , hidden = model (X_local_minibatch , hidden ) # forward pass
151+ y_pred , hidden_state = model (X_local_minibatch , hidden_state ) # forward pass
153152
154153 # Stateful = False for training. Do we go Stateful = True during inference/prediction time?
155154 if not stateful :
156- hidden = None
155+ hidden_state = None
157156 else :
158- h_0 , c_0 = hidden
157+ h_0 , c_0 = hidden_state
159158 h_0 .detach_ (), c_0 .detach_ ()
160- hidden = (h_0 , c_0 )
159+ hidden_state = (h_0 , c_0 )
161160
162161 loss = loss_function (y_pred , y_local_minibatch ) # compute loss
163- loss .backward () # reeeeewind ( backward pass)
162+ loss .backward () # backward pass
164163 optimizer .step () # parameter update
165164
166165 train_running_loss += loss .detach ().item () # unpacks the tensor into a scalar value
@@ -171,15 +170,15 @@ def get_accuracy(self, logits, target):
171170 % (epoch , train_running_loss / num_batches , train_acc / num_batches )
172171 )
173172
174- print ("Validation ..." ) # should this be done every N epochs
175173 if epoch % 10 == 0 :
174+ print ("Validation ..." ) # should this be done every N=10 epochs
176175 val_running_loss , val_acc = 0.0 , 0.0
177176
178177 # Compute validation loss, accuracy. Use torch.no_grad() & model.eval()
179178 with torch .no_grad ():
180179 model .eval ()
181180
182- hidden = None
181+ hidden_state = None
183182 for i in range (num_dev_batches ):
184183 X_local_validation_minibatch , y_local_validation_minibatch = (
185184 dev_X [i * batch_size : (i + 1 ) * batch_size , ],
@@ -188,9 +187,9 @@ def get_accuracy(self, logits, target):
188187 X_local_minibatch = X_local_validation_minibatch .permute (1 , 0 , 2 )
189188 y_local_minibatch = torch .max (y_local_validation_minibatch , 1 )[1 ]
190189
191- y_pred , hidden = model (X_local_minibatch , hidden )
190+ y_pred , hidden_state = model (X_local_minibatch , hidden_state )
192191 if not stateful :
193- hidden = None
192+ hidden_state = None
194193
195194 val_loss = loss_function (y_pred , y_local_minibatch )
196195
0 commit comments