Skip to content

Commit 54e617e

Browse files
committed
New experiment testing the exponent value; save config files to disk; add an option to save embedding images during training for making a video
1 parent 5dc16aa commit 54e617e

File tree

6 files changed

+363
-169
lines changed

6 files changed

+363
-169
lines changed

notebooks/permutation_group.ipynb

Lines changed: 231 additions & 123 deletions
Large diffs are not rendered by default.

src/run_exp.py

Lines changed: 105 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from src.utils.crystal_metric import crystal_metric
1616
import json
1717

18+
import os
19+
from datetime import datetime
20+
1821
data_id_choices = ["lattice", "greater", "family_tree", "equivalence", "circle", "permutation"]
1922
model_id_choices = ["H_MLP", "standard_MLP", "H_transformer", "standard_transformer"]
2023
if __name__ == '__main__':
@@ -23,15 +26,21 @@
2326
parser.add_argument('--data_id', type=str, required=True, choices=data_id_choices, help='Data ID')
2427
parser.add_argument('--model_id', type=str, required=True, choices=model_id_choices, help='Model ID')
2528

26-
results_root = "results_embd_n"
27-
2829
args = parser.parse_args()
2930
seed = args.seed
3031
data_id = args.data_id
3132
model_id = args.model_id
3233

34+
## ------------------------ CONFIG -------------------------- ##
35+
3336
data_size = 1000
3437
train_ratio = 0.8
38+
embd_dim = 16
39+
40+
lr = 0.002
41+
weight_decay = 0.01
42+
43+
n_exp=embd_dim
3544

3645
param_dict = {
3746
'seed': seed,
@@ -40,9 +49,24 @@
4049
'train_ratio': train_ratio,
4150
'model_id': model_id,
4251
'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
43-
'embd_dim': 16,
52+
'embd_dim': embd_dim,
53+
'n_exp': n_exp,
54+
'lr': lr,
55+
'weight_decay':weight_decay
4456
}
4557

58+
results_root = "../results_test"
59+
60+
current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
61+
results_root = f"{results_root}/{current_datetime}"
62+
os.mkdir(results_root)
63+
64+
param_dict_json = {k: v for k, v in param_dict.items() if k != 'device'} # since torch.device is not JSON serializable
65+
66+
67+
with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_config.json", "w") as f:
68+
json.dump(param_dict_json, f, indent=4)
69+
4670
aux_info = {}
4771
if data_id == "lattice":
4872
aux_info["lattice_size"] = 5
@@ -59,20 +83,20 @@
5983
else:
6084
raise ValueError(f"Unknown data_id: {data_id}")
6185

62-
# # Train the model
63-
# print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}")
64-
# ret_dic = train_single_model(param_dict)
86+
# Train the model
87+
print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}")
88+
ret_dic = train_single_model(param_dict)
6589

66-
# ## Exp1: Visualize Embeddings
67-
# print(f"Experiment 1: Visualize Embeddings")
68-
# model = ret_dic['model']
69-
# dataset = ret_dic['dataset']
70-
# torch.save(model.state_dict(), f"../{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.pt")
90+
## Exp1: Visualize Embeddings
91+
print(f"Experiment 1: Visualize Embeddings")
92+
model = ret_dic['model']
93+
dataset = ret_dic['dataset']
94+
torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt")
7195

72-
# if hasattr(model.embedding, 'weight'):
73-
# visualize_embedding(model.embedding.weight.cpu(), title=f"{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}", save_path=f"../{results_root}/emb_{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.png", dict_level = dataset['dict_level'] if 'dict_level' in dataset else None, color_dict = False if data_id == "permutation" else True, adjust_overlapping_text = False)
74-
# else:
75-
# visualize_embedding(model.embedding.data.cpu(), title=f"{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}", save_path=f"../{results_root}/emb_{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.png", dict_level = dataset['dict_level'] if 'dict_level' in dataset else None, color_dict = False if data_id == "permutation" else True, adjust_overlapping_text = False)
96+
if hasattr(model.embedding, 'weight'):
97+
visualize_embedding(model.embedding.weight.cpu(), title=f"{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}", save_path=f"{results_root}/emb_{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.png", dict_level = dataset['dict_level'] if 'dict_level' in dataset else None, color_dict = False if data_id == "permutation" else True, adjust_overlapping_text = False)
98+
else:
99+
visualize_embedding(model.embedding.data.cpu(), title=f"{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}", save_path=f"{results_root}/emb_{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.png", dict_level = dataset['dict_level'] if 'dict_level' in dataset else None, color_dict = False if data_id == "permutation" else True, adjust_overlapping_text = False)
76100

77101

78102
# ## Exp2: Metric vs Overall Dataset Size (fixed train-test split)
@@ -87,15 +111,19 @@
87111
# 'train_ratio': train_ratio,
88112
# 'model_id': model_id,
89113
# 'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
90-
# 'embd_dim': 16,
114+
# 'embd_dim': embd_dim,
115+
# 'n_exp': n_exp,
116+
# 'lr': lr,
117+
# 'weight_decay':weight_decay
91118
# }
92-
# print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id} with train_ratio {train_ratio} and data_size {data_size}")
119+
120+
# print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}")
93121
# ret_dic = train_single_model(param_dict)
94122
# model = ret_dic['model']
95123
# dataset = ret_dic['dataset']
96124

97-
# torch.save(model.state_dict(), f"../{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.pt")
98-
# with open(f"../{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_train_results.json", "w") as f:
125+
# torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt")
126+
# with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f:
99127
# json.dump(ret_dic["results"], f, indent=4)
100128

101129
# if data_id == "family_tree":
@@ -106,7 +134,7 @@
106134
# else:
107135
# metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info)
108136

109-
# with open(f"../{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.json", "w") as f:
137+
# with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.json", "w") as f:
110138
# json.dump(metric_dict, f, indent=4)
111139

112140
## Exp3: Metric vs Train Fraction (fixed dataset size)
@@ -122,15 +150,18 @@
122150
'train_ratio': train_ratio,
123151
'model_id': model_id,
124152
'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
125-
'embd_dim': 16,
153+
'embd_dim': embd_dim,
154+
'n_exp': n_exp,
155+
'lr': lr,
156+
'weight_decay':weight_decay
126157
}
127-
print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id} with train_ratio {train_ratio} and data_size {data_size}")
158+
print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}")
128159
ret_dic = train_single_model(param_dict)
129160
model = ret_dic['model']
130161
dataset = ret_dic['dataset']
131162

132-
torch.save(model.state_dict(), f"../{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.pt")
133-
with open(f"../{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_train_results.json", "w") as f:
163+
torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt")
164+
with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f:
134165
json.dump(ret_dic["results"], f, indent=4)
135166

136167
if data_id == "family_tree":
@@ -141,7 +172,7 @@
141172
else:
142173
metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info)
143174

144-
with open(f"../{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.json", "w") as f:
175+
with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_metric.json", "w") as f:
145176
json.dump(metric_dict, f, indent=4)
146177

147178
## Exp4: Grokking plot: Run with different seeds
@@ -160,14 +191,57 @@
160191
'train_ratio': train_ratio,
161192
'model_id': model_id,
162193
'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
163-
'embd_dim': 16,
194+
'embd_dim': embd_dim,
195+
'n_exp': n_exp,
196+
'lr': lr,
197+
'weight_decay':weight_decay
198+
}
199+
print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}")
200+
ret_dic = train_single_model(param_dict)
201+
model = ret_dic['model']
202+
dataset = ret_dic['dataset']
203+
torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt")
204+
with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f:
205+
json.dump(ret_dic["results"], f, indent=4)
206+
207+
if data_id == "family_tree":
208+
aux_info["dict_level"] = dataset['dict_level']
209+
210+
if hasattr(model.embedding, 'weight'):
211+
metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info)
212+
else:
213+
metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info)
214+
215+
with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.json", "w") as f:
216+
json.dump(metric_dict, f, indent=4)
217+
218+
# Exp5: Exponent value plot: run with different n values; plot test accuracy and explained variance vs. n
219+
220+
print(f"Experiment 5: Train with different exponent values")
221+
n_list = np.arange(1, 17, dtype=int)
222+
223+
for i in tqdm(range(len(n_list))):
224+
n_exp = n_list[i]
225+
data_size = 1000
226+
train_ratio = 0.8
227+
228+
param_dict = {
229+
'seed': seed,
230+
'data_id': data_id,
231+
'data_size': data_size,
232+
'train_ratio': train_ratio,
233+
'model_id': model_id,
234+
'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
235+
'embd_dim': embd_dim,
236+
'n_exp': n_exp
164237
}
165-
print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id} with train_ratio {train_ratio} and data_size {data_size}")
238+
print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}")
239+
166240
ret_dic = train_single_model(param_dict)
167241
model = ret_dic['model']
168242
dataset = ret_dic['dataset']
169-
torch.save(model.state_dict(), f"../{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.pt")
170-
with open(f"../{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_train_results.json", "w") as f:
243+
torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt")
244+
with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f:
171245
json.dump(ret_dic["results"], f, indent=4)
172246

173247
if data_id == "family_tree":
@@ -178,6 +252,6 @@
178252
else:
179253
metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info)
180254

181-
with open(f"../{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.json", "w") as f:
255+
with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.json", "w") as f:
182256
json.dump(metric_dict, f, indent=4)
183-
257+

src/utils/dataset.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def permutation_group_dataset(p, num, seed=0, device='cpu'):
6969
torch.manual_seed(seed)
7070
np.random.seed(seed)
7171

72-
perms = list(itertools.permutations(range(4)))
72+
perms = list(itertools.permutations(range(p)))
7373
num_perms = len(perms)
7474

7575
perm_dict = dict(enumerate(perms))
@@ -79,13 +79,9 @@ def permutation_group_dataset(p, num, seed=0, device='cpu'):
7979

8080
data_id = [[perms[int(i)], perms[int(j)]] for i, j in torch.cartesian_prod(idx, idx)]
8181
keyed_data_id = np.array([[swapped_dict[data_id[i][0]], swapped_dict[data_id[i][1]]] for i in range(len(data_id))])
82-
# data_id = np.fromiter([[tuple(perms[i]), tuple(perms[j])] for i, j in zip(idx1, idx2)], object)
83-
# data_id = np.array([[perms_list[int(i)], perms_list[int(j)]] for i, j in torch.cartesian_prod(idx, idx)])
8482

8583
labels = [tuple(np.array(perms[int(i)])[np.array(perms[int(j)])]) for i, j in torch.cartesian_prod(idx, idx)]
8684
keyed_labels = np.array([swapped_dict[labels[i]] for i in range(len(labels))])
87-
# labels = [sum(a != b for a, b in zip(lbl, idx)) for lbl in labels]
88-
# labels = np.array([sum(math.pow(10, i) * num for i, num in enumerate(reversed(tup))) for tup in labels]).astype(int)
8985
labels = torch.tensor(labels, dtype=torch.long, device=device)
9086

9187
perm_vals = ["".join(np.array(perm_dict[i]).astype(str)) for i in range(len(perm_dict))]

src/utils/driver.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def train_single_model(param_dict: dict):
4747
raise ValueError("device must be provided in param_dict")
4848
if "embd_dim" not in param_dict:
4949
raise ValueError("embd_dim must be provided in param_dict")
50+
if "n_exp" not in param_dict:
51+
raise ValueError("n_exp must be provided in param_dict")
52+
5053

5154
seed = param_dict['seed']
5255
data_id = param_dict['data_id']
@@ -55,13 +58,17 @@ def train_single_model(param_dict: dict):
5558
model_id = param_dict['model_id']
5659
device = param_dict['device']
5760
embd_dim = param_dict['embd_dim']
61+
n_exp = param_dict['n_exp']
62+
63+
video = False if 'video' not in param_dict else param_dict['video']
64+
lr = 0.002 if 'lr' not in param_dict else param_dict['lr']
65+
weight_decay = 0.01 if 'weight_decay' not in param_dict else param_dict['weight_decay']
5866

5967
set_seed(seed)
6068

61-
6269
# define dataset
6370
input_token = 2
64-
num_epochs = None
71+
num_epochs = 7000
6572
if data_id == "lattice":
6673
dataset = parallelogram_dataset(p=5, dim=2, num=data_size, seed=seed, device=device)
6774
input_token = 3
@@ -77,28 +84,27 @@ def train_single_model(param_dict: dict):
7784
elif data_id=="permutation":
7885
dataset = permutation_group_dataset(p=4, num=data_size, seed=seed, device=device)
7986
if model_id == "H_transformer" or model_id == "standard_transformer":
80-
num_epochs = 12750 # extra epochs to train fully
87+
num_epochs = 10000 # extra epochs to train fully
8188
else:
8289
raise ValueError(f"Unknown data_id: {data_id}")
8390

8491
dataset = split_dataset(dataset, train_ratio=train_ratio, seed=seed)
8592
vocab_size = dataset['vocab_size']
8693

87-
8894
# define model
8995
if model_id == "H_MLP":
9096
weight_tied = True
9197
hidden_size = 100
9298
shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
93-
model = MLP_HS(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim, input_token=input_token, weight_tied=weight_tied, seed=seed, n=np.embd_dim, init_scale=1).to(device)
99+
model = MLP_HS(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim, input_token=input_token, weight_tied=weight_tied, seed=seed, n=n_exp, init_scale=1).to(device)
94100
elif model_id == "standard_MLP":
95101
unembd = True
96102
weight_tied = True
97103
hidden_size = 100
98104
shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
99105
model = MLP(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim, input_token=input_token, unembd=unembd, weight_tied=weight_tied, seed=seed, init_scale=1).to(device)
100106
elif model_id == "H_transformer":
101-
model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=2, num_layers=2, n_dist=embd_dim,seq_len=input_token, seed=seed, use_dist_layer=True, init_scale=1).to(device)
107+
model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=2, num_layers=2, n_dist=n_exp,seq_len=input_token, seed=seed, use_dist_layer=True, init_scale=1).to(device)
102108
elif model_id == "standard_transformer":
103109
model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=2, num_layers=2, seq_len=input_token, seed=seed, use_dist_layer=False, init_scale=1).to(device)
104110
else:
@@ -112,7 +118,7 @@ def train_single_model(param_dict: dict):
112118
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
113119

114120
ret_dic = {}
115-
ret_dic["results"] = model.train(param_dict={'num_epochs': num_epochs if num_epochs else 7000, 'learning_rate': 0.002, 'train_dataloader': train_dataloader, 'test_dataloader': test_dataloader, 'device': device})
121+
ret_dic["results"] = model.train(param_dict={'num_epochs': num_epochs, 'learning_rate': lr, 'weight_decay':weight_decay, 'train_dataloader': train_dataloader, 'test_dataloader': test_dataloader, 'device': device, 'video': video})
116122
ret_dic["model"] = model
117123
ret_dic["dataset"] = dataset
118124

src/utils/model.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import math
77
from src.utils.dataset import *
8+
from src.utils.visualization import *
89

910
import sys
1011
# import keyboard
@@ -23,7 +24,8 @@ def train(self, param_dict: dict):
2324
train_dataloader = param_dict['train_dataloader']
2425
test_dataloader = param_dict['test_dataloader']
2526
device = param_dict['device']
26-
weight_decay = 0.01 if 'weight_decay' not in param_dict else param_dict['weight_decay']
27+
weight_decay = param_dict['weight_decay']
28+
video = False if 'video' not in param_dict else param_dict['video']
2729

2830
verbose = True
2931
if 'verbose' in param_dict:
@@ -45,6 +47,12 @@ def train(self, param_dict: dict):
4547
# if keyboard.is_pressed('ctrl+d'):
4648
# print("Manual early stopping occurring.")
4749
# break
50+
if video and epoch%10 == 0: # save every 10 epochs
51+
if hasattr(self.embedding, 'weight'):
52+
embd = self.embedding.weight
53+
else:
54+
embd = self.embedding.data
55+
visualize_embedding(embd, title=f"Epoch {epoch}", save_path=f"../video_imgs/{epoch}.png", dict_level = None, color_dict = True, adjust_overlapping_text = False)
4856

4957
train_loss = 0
5058
train_correct = 0
@@ -305,7 +313,7 @@ def forward(self, x):
305313
return logits
306314

307315

308-
def load_model_from_file(model_id, data_id, data_size = 1000, train_ratio=0.8,seed=66, embd_dim=16, device='cpu'):
316+
def load_model_from_file(model_id, data_id, results_root = "results",data_size = 1000, train_ratio=0.8,seed=66, embd_dim=16, device='cpu'):
309317

310318
input_token=2
311319

@@ -348,6 +356,6 @@ def load_model_from_file(model_id, data_id, data_size = 1000, train_ratio=0.8,se
348356
else:
349357
raise ValueError(f"Unknown model_id: {model_id}")
350358

351-
model.load_state_dict(torch.load(f"../results/{seed}_permutation_{model_id}_{data_size}_{train_ratio}.pt"))
359+
model.load_state_dict(torch.load(f"../{results_root}/{seed}_permutation_{model_id}_{data_size}_{train_ratio}.pt"))
352360

353361
return model

src/utils/visualization.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def visualize_embedding(emb, title="", save_path=None, dict_level = None, color_
4343
adjust_text(texts, x=x, y=y, autoalign='xy', force_points=0.5, only_move = {'text':'xy'})
4444
if save_path:
4545
plt.savefig(save_path)
46+
plt.show()
47+
plt.close()
4648

4749

4850
def visualize_embedding_3d(emb, title="", save_path=None, dict_level = None, color_dict=True):

0 commit comments

Comments
 (0)