@@ -92,8 +92,9 @@ def validation_step(self, batch, batch_idx):
9292 if not self .stateful :
9393 self .hidden = None
9494
95- y_local_minibatch = torch .max (y_local_minibatch , 1 )[1 ] # NLLLoss expects class indices
95+ y_local_minibatch = torch .max (y_local_minibatch , 1 )[1 ] # NLLLoss expects class indices
9696 val_loss = self .loss_function (y_pred , y_local_minibatch ) # compute loss
97+ self .log ("val_loss" , val_loss )
9798 return val_loss
9899
99100 def configure_optimizers (self ):
@@ -103,7 +104,7 @@ def configure_optimizers(self):
103104
104105class GTZANDataset (data .Dataset ):
105106
106- def __init__ (self , partition ):
107+ def __init__ (self , partition ) -> object :
107108 self .partition = partition
108109 self .genre_features = GenreFeatureData ()
109110
@@ -132,12 +133,16 @@ def __init__(self, partition):
132133 self .test_Y = torch .from_numpy (self .genre_features .test_Y ).type (torch .LongTensor )
133134
134135 # Convert {training, test} torch.Tensors
135- print ("Training X shape: " + str (self .genre_features .train_X .shape ))
136- print ("Training Y shape: " + str (self .genre_features .train_Y .shape ))
137- print ("Validation X shape: " + str (self .genre_features .dev_X .shape ))
138- print ("Validation Y shape: " + str (self .genre_features .dev_Y .shape ))
139- print ("Test X shape: " + str (self .genre_features .test_X .shape ))
140- print ("Test Y shape: " + str (self .genre_features .test_Y .shape ))
136+ if self .partition == 'train' :
137+ print ("Training X shape: " + str (self .genre_features .train_X .shape ))
138+ print ("Training Y shape: " + str (self .genre_features .train_Y .shape ))
139+ elif self .partition == 'dev' :
140+ print ("Validation X shape: " + str (self .genre_features .dev_X .shape ))
141+ print ("Validation Y shape: " + str (self .genre_features .dev_Y .shape ))
142+ elif self .partition == 'test' :
143+ print ("Test X shape: " + str (self .genre_features .test_X .shape ))
144+ print ("Test Y shape: " + str (self .genre_features .test_Y .shape ))
145+
141146
142147 def __getitem__ (self , index ):
143148 # train_X shape: (total # of training examples, sequence_length, input_dim)
@@ -146,7 +151,6 @@ def __getitem__(self, index):
146151 if self .partition == 'train' :
147152 X_training_example_at_index = self .train_X [index , ]
148153 y_training_example_at_index = self .train_Y [index , ]
149- # torch.index_select(self.train_X, dim=0, index=torch.tensor([index], dtype=torch.long))
150154
151155 elif self .partition == 'dev' :
152156 X_training_example_at_index = self .dev_X [index , ]
@@ -170,7 +174,10 @@ def __len__(self):
170174class MusicGenreDataModule (pl .LightningDataModule ):
171175 def __init__ (self , batch_size : int = 35 ) -> None :
172176 super ().__init__ ()
173- self .batch_size = batch_size # num of training examples per minibatch
177+ self .dev_dataset = None
178+ self .test_dataset = None
179+ self .train_dataset = None
180+ self .batch_size = batch_size # num of training examples per minibatch
174181
175182 def setup (self , stage ):
176183 self .train_dataset = GTZANDataset ('train' )
@@ -189,12 +196,9 @@ def test_dataloader(self) -> DataLoader:
189196
190197if __name__ == "__main__" :
191198 model = MusicGenreClassifer ()
192- trainer = pl .Trainer ()
199+ trainer = pl .Trainer (max_epochs = 5 )
193200 genre_dm = MusicGenreDataModule ()
194201
195- # dataloader fit
196- # trainer.fit(model, train_loader, val_loader)
197-
198202 # datamodel fit
199203 trainer .fit (model , genre_dm )
200204
0 commit comments