|
3 | 3 | import torch.optim as optim |
4 | 4 | import numpy as np |
5 | 5 | import random |
| 6 | +import optuna |
| 7 | +import joblib |
6 | 8 |
|
7 | 9 | from tqdm import tqdm |
8 | 10 |
|
|
20 | 22 |
|
21 | 23 | data_id_choices = ["lattice", "greater", "family_tree", "equivalence", "circle", "permutation"] |
22 | 24 | model_id_choices = ["H_MLP", "standard_MLP", "H_transformer", "standard_transformer"] |
23 | | -split_choices = [1,2,3,4,5,6,7] |
24 | | -wd_choices = [0.003, 0.005, 0.007, 0.01, 0.012, 0.015, 0.02, 0.03, 0.05, 0.07, 0.1] |
| 25 | +split_choices = [1,2,3,4,5,6,7, 8] |
| 26 | +wd_choices = [0.0005, 0.001, 0.003, 0.005, 0.007, 0.01, 0.012, 0.015, 0.02, 0.03, 0.05, 0.07, 0.1] |
25 | 27 | if __name__ == '__main__': |
26 | 28 | parser = argparse.ArgumentParser(description='Experiment') |
27 | 29 | parser.add_argument('--seed', type=int, default=66, help='random seed') |
|
43 | 45 | embd_dim = 16 |
44 | 46 |
|
45 | 47 | lr = 0.002 |
46 | | -weight_decay = 0.005 |
| 48 | +weight_decay = args.wd |
47 | 49 |
|
48 | 50 | n_exp=1 |
49 | 51 |
|
|
89 | 91 | else: |
90 | 92 | raise ValueError(f"Unknown data_id: {data_id}") |
91 | 93 |
|
| 94 | +# # Optuna study to tune the learning rate (lr) and weight decay (wd), minimizing the mean of the last 10 test losses |
| 95 | +# def loss_objective(trial): |
| 96 | +# weight_decay = trial.suggest_float('wd', 0, 0.01) |
| 97 | +# lr = trial.suggest_float('lr', 0.002, 0.005) |
| 98 | + |
| 99 | +# param_dict = { |
| 100 | +# 'seed': seed, |
| 101 | +# 'data_id': data_id, |
| 102 | +# 'data_size': data_size, |
| 103 | +# 'train_ratio': train_ratio, |
| 104 | +# 'model_id': model_id, |
| 105 | +# 'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'), |
| 106 | +# 'embd_dim': embd_dim, |
| 107 | +# 'n_exp': n_exp, |
| 108 | +# 'lr': lr, |
| 109 | +# 'weight_decay':weight_decay |
| 110 | +# } |
| 111 | + |
| 112 | +# ret_dic = train_single_model(param_dict) |
| 113 | + |
| 114 | +# test_loss = np.mean(ret_dic["results"]["test_losses"][-10:]) |
| 115 | + |
| 116 | +# return test_loss |
| 117 | + |
| 118 | +# study = optuna.create_study() |
| 119 | +# study.optimize(loss_objective, n_trials = 15) |
| 120 | +# joblib.dump(study, "wd_lr_study.pkl") |
| 121 | + |
| 122 | +# print(study.best_params) |
| 123 | + |
92 | 124 | # # Train the model |
93 | 125 | # print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}, weight decay {weight_decay}") |
94 | 126 | # ret_dic = train_single_model(param_dict) |
|
209 | 241 | seed_list = np.linspace(0, 1000, 20, dtype=int)[7:10] |
210 | 242 | if split == 6: |
211 | 243 | seed_list = np.linspace(0, 1000, 20, dtype=int)[10:13] |
212 | | -if split == 5: |
213 | | - seed_list = np.linspace(0, 1000, 20, dtype=int)[13:17] |
214 | 244 | if split == 7: |
| 245 | + seed_list = np.linspace(0, 1000, 20, dtype=int)[13:17] |
| 246 | +if split == 8: |
215 | 247 | seed_list = np.linspace(0, 1000, 20, dtype=int)[17:] |
216 | 248 |
|
217 | 249 |
|
|
0 commit comments