@@ -25,6 +25,7 @@ def train(self, param_dict: dict):
2525 test_dataloader = param_dict ['test_dataloader' ]
2626 device = param_dict ['device' ]
2727 weight_decay = param_dict ['weight_decay' ]
28+ model_id = param_dict ['model_id' ]
2829 video = False if 'video' not in param_dict else param_dict ['video' ]
2930
3031 verbose = True
@@ -61,14 +62,16 @@ def train(self, param_dict: dict):
6162 batch_inputs = batch_inputs .to (device )
6263 batch_targets = batch_targets .type (torch .LongTensor ).to (device )
6364 optimizer .zero_grad ()
64- logits = self .forward (batch_inputs )
65+ outputs = self .forward (batch_inputs )
6566
6667# class_counts = torch.bincount(batch_targets.squeeze(), minlength=self.vocab_size).double() + 1e-8
6768# class_weights = 1 / class_counts.cuda()
6869
6970 criterion = nn .CrossEntropyLoss ()#weight=class_weights)
70-
71- loss = criterion (logits , batch_targets .squeeze ())
71+ if 'H_' in model_id : # Harmonic Model
72+ loss = (- 1 )* (outputs [torch .arange (outputs .size (0 )), batch_targets .squeeze ()].mean ())
73+ else :
74+ loss = criterion (outputs , batch_targets .squeeze ())
7275
7376 if hasattr (self .embedding , 'weight' ):
7477 total_loss = loss + lamb_reg * torch .mean (torch .sqrt (torch .mean (self .embedding .weight ** 2 , dim = 0 )))
@@ -80,7 +83,7 @@ def train(self, param_dict: dict):
8083 train_loss += loss .item ()
8184
8285 # Compute training accuracy
83- _ , predicted = torch .max (logits , 1 )
86+ _ , predicted = torch .max (outputs , 1 )
8487 train_correct += (predicted == batch_targets ).sum ().item ()
8588 train_total += batch_targets .size (0 )
8689