Skip to content

Commit ad2b6c4

Browse files
committed
Fixed bug: pass model_id through train_single_model into model.train's param_dict, and select the loss by model type — Harmonic models ('H_' prefix) use a negative mean log-likelihood over the target logits instead of CrossEntropyLoss; also raised n_exp from 1 to 4 in run_exp.py
1 parent cbe3b98 commit ad2b6c4

File tree

4 files changed

+113
-13
lines changed

4 files changed

+113
-13
lines changed

notebooks/case_study_circle.ipynb

Lines changed: 104 additions & 7 deletions
Large diffs are not rendered by default.

src/run_exp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
lr = 0.002
4141
weight_decay = 0.1
4242

43-
n_exp=1
43+
n_exp=4
4444

4545
param_dict = {
4646
'seed': seed,

src/utils/driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def train_single_model(param_dict: dict):
120120
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
121121

122122
ret_dic = {}
123-
ret_dic["results"] = model.train(param_dict={'num_epochs': num_epochs, 'learning_rate': lr, 'weight_decay':weight_decay, 'train_dataloader': train_dataloader, 'test_dataloader': test_dataloader, 'device': device, 'video': video, 'verbose': verbose, 'lambda': lamb_reg})
123+
ret_dic["results"] = model.train(param_dict={'model_id':model_id,'num_epochs': num_epochs, 'learning_rate': lr, 'weight_decay':weight_decay, 'train_dataloader': train_dataloader, 'test_dataloader': test_dataloader, 'device': device, 'video': video, 'verbose': verbose, 'lambda': lamb_reg})
124124
ret_dic["model"] = model
125125
ret_dic["dataset"] = dataset
126126

src/utils/model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def train(self, param_dict: dict):
2525
test_dataloader = param_dict['test_dataloader']
2626
device = param_dict['device']
2727
weight_decay = param_dict['weight_decay']
28+
model_id = param_dict['model_id']
2829
video = False if 'video' not in param_dict else param_dict['video']
2930

3031
verbose = True
@@ -61,14 +62,16 @@ def train(self, param_dict: dict):
6162
batch_inputs = batch_inputs.to(device)
6263
batch_targets = batch_targets.type(torch.LongTensor).to(device)
6364
optimizer.zero_grad()
64-
logits = self.forward(batch_inputs)
65+
outputs = self.forward(batch_inputs)
6566

6667
# class_counts = torch.bincount(batch_targets.squeeze(), minlength=self.vocab_size).double() + 1e-8
6768
# class_weights = 1 / class_counts.cuda()
6869

6970
criterion = nn.CrossEntropyLoss()#weight=class_weights)
70-
71-
loss = criterion(logits, batch_targets.squeeze())
71+
if 'H_' in model_id: # Harmonic Model
72+
loss = (-1)*(outputs[torch.arange(outputs.size(0)), batch_targets.squeeze()].mean())
73+
else:
74+
loss = criterion(outputs, batch_targets.squeeze())
7275

7376
if hasattr(self.embedding, 'weight'):
7477
total_loss = loss + lamb_reg * torch.mean(torch.sqrt(torch.mean(self.embedding.weight**2, dim=0)))
@@ -80,7 +83,7 @@ def train(self, param_dict: dict):
8083
train_loss += loss.item()
8184

8285
# Compute training accuracy
83-
_, predicted = torch.max(logits, 1)
86+
_, predicted = torch.max(outputs, 1)
8487
train_correct += (predicted == batch_targets).sum().item()
8588
train_total += batch_targets.size(0)
8689

0 commit comments

Comments (0)