Skip to content

Commit 26c166d

Browse files
committed
Updated figures
1 parent 7db84d0 commit 26c166d

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed
-52 Bytes
Loading
190 Bytes
Loading

src/run_exp.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import torch.optim as optim
44
import numpy as np
55
import random
6+
import optuna
7+
import joblib
68

79
from tqdm import tqdm
810

@@ -20,8 +22,8 @@
2022

2123
data_id_choices = ["lattice", "greater", "family_tree", "equivalence", "circle", "permutation"]
2224
model_id_choices = ["H_MLP", "standard_MLP", "H_transformer", "standard_transformer"]
23-
split_choices = [1,2,3,4,5,6,7]
24-
wd_choices = [0.003, 0.005, 0.007, 0.01, 0.012, 0.015, 0.02, 0.03, 0.05, 0.07, 0.1]
25+
split_choices = [1,2,3,4,5,6,7, 8]
26+
wd_choices = [0.0005, 0.001, 0.003, 0.005, 0.007, 0.01, 0.012, 0.015, 0.02, 0.03, 0.05, 0.07, 0.1]
2527
if __name__ == '__main__':
2628
parser = argparse.ArgumentParser(description='Experiment')
2729
parser.add_argument('--seed', type=int, default=66, help='random seed')
@@ -43,7 +45,7 @@
4345
embd_dim = 16
4446

4547
lr = 0.002
46-
weight_decay = 0.005
48+
weight_decay = args.wd
4749

4850
n_exp=1
4951

@@ -89,6 +91,36 @@
8991
else:
9092
raise ValueError(f"Unknown data_id: {data_id}")
9193

94+
# # Optuna study for lr/wd
95+
# def loss_objective(trial):
96+
# weight_decay = trial.suggest_float('wd', 0, 0.01)
97+
# lr = trial.suggest_float('lr', 0.002, 0.005)
98+
99+
# param_dict = {
100+
# 'seed': seed,
101+
# 'data_id': data_id,
102+
# 'data_size': data_size,
103+
# 'train_ratio': train_ratio,
104+
# 'model_id': model_id,
105+
# 'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
106+
# 'embd_dim': embd_dim,
107+
# 'n_exp': n_exp,
108+
# 'lr': lr,
109+
# 'weight_decay':weight_decay
110+
# }
111+
112+
# ret_dic = train_single_model(param_dict)
113+
114+
# test_loss = np.mean(ret_dic["results"]["test_losses"][-10:])
115+
116+
# return test_loss
117+
118+
# study = optuna.create_study()
119+
# study.optimize(loss_objective, n_trials = 15)
120+
# joblib.dump(study, "wd_lr_study.pkl")
121+
122+
# print(study.best_params)
123+
92124
# # Train the model
93125
# print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}, weight decay {weight_decay}")
94126
# ret_dic = train_single_model(param_dict)
@@ -209,9 +241,9 @@
209241
seed_list = np.linspace(0, 1000, 20, dtype=int)[7:10]
210242
if split == 6:
211243
seed_list = np.linspace(0, 1000, 20, dtype=int)[10:13]
212-
if split == 5:
213-
seed_list = np.linspace(0, 1000, 20, dtype=int)[13:17]
214244
if split == 7:
245+
seed_list = np.linspace(0, 1000, 20, dtype=int)[13:17]
246+
if split == 8:
215247
seed_list = np.linspace(0, 1000, 20, dtype=int)[17:]
216248

217249

0 commit comments

Comments (0)