Skip to content

Commit e46e129

Browse files
committed
[CLEANUP] logging, attempts to debug validation progress bar issues & warnings, as well as proper per-step, per-epoch reporting
1 parent a5a24fb commit e46e129

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

lstm_genre_classifier_pytorch_lightning.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,9 @@ def validation_step(self, batch, batch_idx):
9292
if not self.stateful:
9393
self.hidden = None
9494

95-
y_local_minibatch = torch.max(y_local_minibatch, 1)[1] # NLLLoss expects class indices
95+
y_local_minibatch = torch.max(y_local_minibatch, 1)[1] # NLLLoss expects class indices
9696
val_loss = self.loss_function(y_pred, y_local_minibatch) # compute loss
97+
self.log("val_loss", val_loss)
9798
return val_loss
9899

99100
def configure_optimizers(self):
@@ -103,7 +104,7 @@ def configure_optimizers(self):
103104

104105
class GTZANDataset(data.Dataset):
105106

106-
def __init__(self, partition):
107+
def __init__(self, partition) -> object:
107108
self.partition = partition
108109
self.genre_features = GenreFeatureData()
109110

@@ -132,12 +133,16 @@ def __init__(self, partition):
132133
self.test_Y = torch.from_numpy(self.genre_features.test_Y).type(torch.LongTensor)
133134

134135
# Convert {training, test} torch.Tensors
135-
print("Training X shape: " + str(self.genre_features.train_X.shape))
136-
print("Training Y shape: " + str(self.genre_features.train_Y.shape))
137-
print("Validation X shape: " + str(self.genre_features.dev_X.shape))
138-
print("Validation Y shape: " + str(self.genre_features.dev_Y.shape))
139-
print("Test X shape: " + str(self.genre_features.test_X.shape))
140-
print("Test Y shape: " + str(self.genre_features.test_Y.shape))
136+
if self.partition == 'train':
137+
print("Training X shape: " + str(self.genre_features.train_X.shape))
138+
print("Training Y shape: " + str(self.genre_features.train_Y.shape))
139+
elif self.partition == 'dev':
140+
print("Validation X shape: " + str(self.genre_features.dev_X.shape))
141+
print("Validation Y shape: " + str(self.genre_features.dev_Y.shape))
142+
elif self.partition == 'test':
143+
print("Test X shape: " + str(self.genre_features.test_X.shape))
144+
print("Test Y shape: " + str(self.genre_features.test_Y.shape))
145+
141146

142147
def __getitem__(self, index):
143148
# train_X shape: (total # of training examples, sequence_length, input_dim)
@@ -146,7 +151,6 @@ def __getitem__(self, index):
146151
if self.partition == 'train':
147152
X_training_example_at_index = self.train_X[index, ]
148153
y_training_example_at_index = self.train_Y[index, ]
149-
# torch.index_select(self.train_X, dim=0, index=torch.tensor([index], dtype=torch.long))
150154

151155
elif self.partition == 'dev':
152156
X_training_example_at_index = self.dev_X[index, ]
@@ -170,7 +174,10 @@ def __len__(self):
170174
class MusicGenreDataModule(pl.LightningDataModule):
171175
def __init__(self, batch_size: int = 35) -> None:
172176
super().__init__()
173-
self.batch_size = batch_size # num of training examples per minibatch
177+
self.dev_dataset = None
178+
self.test_dataset = None
179+
self.train_dataset = None
180+
self.batch_size = batch_size # num of training examples per minibatch
174181

175182
def setup(self, stage):
176183
self.train_dataset = GTZANDataset('train')
@@ -189,12 +196,9 @@ def test_dataloader(self) -> DataLoader:
189196

190197
if __name__ == "__main__":
191198
model = MusicGenreClassifer()
192-
trainer = pl.Trainer()
199+
trainer = pl.Trainer(max_epochs=5)
193200
genre_dm = MusicGenreDataModule()
194201

195-
# dataloader fit
196-
# trainer.fit(model, train_loader, val_loader)
197-
198202
# datamodel fit
199203
trainer.fit(model, genre_dm)
200204

0 commit comments

Comments
 (0)