
Commit 659b4e9

Added automatic script for running all toy experiments (run_exp.py)
1 parent 70971e0 commit 659b4e9

File tree

8 files changed: 444 additions, 53 deletions


scripts/lattice.py

Lines changed: 5 additions & 2 deletions

@@ -1,8 +1,11 @@
 import time
 import os
 import sys
-from model import *
-from dataset import *
+
+sys.path.append('..')
+
+from src.utils.model import *
+from src.utils.dataset import *
 import numpy as np
 from sklearn.decomposition import PCA

scripts/modadd.py

Lines changed: 5 additions & 2 deletions

@@ -1,8 +1,11 @@
 import time
 import os
 import sys
-from model import *
-from dataset import *
+
+sys.path.append('..')
+
+from src.utils.model import *
+from src.utils.dataset import *
 import numpy as np
 from sklearn.decomposition import PCA
 import math
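
Both scripts now rely on sys.path.append('..'), which only resolves src.utils when they are launched from inside scripts/. A minimal, working-directory-independent variant of the same fix (a sketch, not part of this commit; it assumes the repository layout with src/utils/ one level above the scripts):

    import os
    import sys

    # Resolve the repository root relative to this file instead of relying on '..'
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    if repo_root not in sys.path:
        sys.path.append(repo_root)

    from src.utils.model import *      # resolvable regardless of the working directory
    from src.utils.dataset import *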

src/run_exp.py

Lines changed: 201 additions & 0 deletions (new file)

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random

from tqdm import tqdm

import sys
sys.path.append("..")

import argparse
from src.utils.driver import train_single_model
from src.utils.visualization import visualize_embedding
from src.utils.crystal_metric import crystal_metric
import json

data_id_choices = ["lattice", "greater", "family_tree", "equivalence", "circle"]
model_id_choices = ["H_MLP", "standard_MLP", "H_transformer", "standard_transformer"]

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Experiment')
    parser.add_argument('--seed', type=int, default=77, help='random seed')
    parser.add_argument('--data_id', type=str, required=True, choices=data_id_choices, help='Data ID')
    parser.add_argument('--model_id', type=str, required=True, choices=model_id_choices, help='Model ID')

    args = parser.parse_args()
    seed = args.seed
    data_id = args.data_id
    model_id = args.model_id

    data_size = 1000
    train_ratio = 0.8

    param_dict = {
        'seed': seed,
        'data_id': data_id,
        'data_size': data_size,
        'train_ratio': train_ratio,
        'model_id': model_id,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'embd_dim': 16,
    }

    # Train the model
    print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}")
    ret_dic = train_single_model(param_dict)

    ## Exp1: Visualize Embeddings
    print(f"Experiment 1: Visualize Embeddings")
    model = ret_dic['model']
    dataset = ret_dic['dataset']
    torch.save(model.state_dict(), f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.pt")

    if hasattr(model.embedding, 'weight'):
        visualize_embedding(model.embedding.weight.cpu(), title=f"{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}", save_path=f"../results/emb_{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.png", dict_level=dataset['dict_level'] if 'dict_level' in dataset else None)
    else:
        visualize_embedding(model.embedding.data.cpu(), title=f"{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}", save_path=f"../results/emb_{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.png", dict_level=dataset['dict_level'] if 'dict_level' in dataset else None)

    ## Exp2: Metric vs Overall Dataset Size (fixed train-test split)
    print(f"Experiment 2: Metric vs Overall Dataset Size (fixed train-test split)")
    data_size_list = [100, 200, 500, 1000, 2000, 5000, 10000]
    for i in tqdm(range(len(data_size_list))):
        data_size = data_size_list[i]
        param_dict = {
            'seed': seed,
            'data_id': data_id,
            'data_size': data_size,
            'train_ratio': train_ratio,
            'model_id': model_id,
            'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            'embd_dim': 16,
        }
        print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id} with train_ratio {train_ratio} and data_size {data_size}")
        ret_dic = train_single_model(param_dict)
        model = ret_dic['model']
        dataset = ret_dic['dataset']

        torch.save(model.state_dict(), f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.pt")
        with open(f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_train_results.json", "w") as f:
            json.dump(ret_dic["results"], f, indent=4)

        aux_info = {}
        if data_id == "lattice":
            aux_info["lattice_size"] = 5
        elif data_id == "greater":
            aux_info["p"] = 30
        elif data_id == "family_tree":
            aux_info["dict_level"] = dataset['dict_level']
        elif data_id == "equivalence":
            aux_info["mod"] = 5
        elif data_id == "circle":
            aux_info["p"] = 59
        else:
            raise ValueError(f"Unknown data_id: {data_id}")

        if hasattr(model.embedding, 'weight'):
            metric_dict = crystal_metric(model.embedding.weight.cpu(), data_id, aux_info)
        else:
            metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info)

        with open(f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.json", "w") as f:
            json.dump(metric_dict, f, indent=4)

    ## Exp3: Metric vs Train Fraction (fixed dataset size)
    print(f"Experiment 3: Metric vs Train Fraction (fixed dataset size)")
    train_ratio_list = np.arange(1, 10) / 10
    data_size = 1000
    for i in tqdm(range(len(train_ratio_list))):
        train_ratio = train_ratio_list[i]
        param_dict = {
            'seed': seed,
            'data_id': data_id,
            'data_size': data_size,
            'train_ratio': train_ratio,
            'model_id': model_id,
            'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            'embd_dim': 16,
        }
        print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id} with train_ratio {train_ratio} and data_size {data_size}")
        ret_dic = train_single_model(param_dict)
        model = ret_dic['model']
        dataset = ret_dic['dataset']

        torch.save(model.state_dict(), f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.pt")
        with open(f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_train_results.json", "w") as f:
            json.dump(ret_dic["results"], f, indent=4)

        aux_info = {}
        if data_id == "lattice":
            aux_info["lattice_size"] = 5
        elif data_id == "greater":
            aux_info["p"] = 30
        elif data_id == "family_tree":
            aux_info["dict_level"] = dataset['dict_level']
        elif data_id == "equivalence":
            aux_info["mod"] = 5
        elif data_id == "circle":
            aux_info["p"] = 59
        else:
            raise ValueError(f"Unknown data_id: {data_id}")

        if hasattr(model.embedding, 'weight'):
            metric_dict = crystal_metric(model.embedding.weight.cpu(), data_id, aux_info)
        else:
            metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info)

        with open(f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.json", "w") as f:
            json.dump(metric_dict, f, indent=4)

    ## Exp4: Grokking plot: Run with different seeds
    print(f"Experiment 4: Train with different seeds")
    seed_list = np.linspace(0, 1000, 20, dtype=int)
    for i in tqdm(range(len(seed_list))):
        seed = seed_list[i]
        data_size = 1000
        train_ratio = 0.8

        param_dict = {
            'seed': seed,
            'data_id': data_id,
            'data_size': data_size,
            'train_ratio': train_ratio,
            'model_id': model_id,
            'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            'embd_dim': 16,
        }
        print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id} with train_ratio {train_ratio} and data_size {data_size}")
        ret_dic = train_single_model(param_dict)

        model = ret_dic['model']
        dataset = ret_dic['dataset']
        torch.save(model.state_dict(), f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.pt")
        with open(f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_train_results.json", "w") as f:
            json.dump(ret_dic["results"], f, indent=4)

        aux_info = {}
        if data_id == "lattice":
            aux_info["lattice_size"] = 5
        elif data_id == "greater":
            aux_info["p"] = 30
        elif data_id == "family_tree":
            aux_info["dict_level"] = dataset['dict_level']
        elif data_id == "equivalence":
            aux_info["mod"] = 5
        elif data_id == "circle":
            aux_info["p"] = 59
        else:
            raise ValueError(f"Unknown data_id: {data_id}")

        if hasattr(model.embedding, 'weight'):
            metric_dict = crystal_metric(model.embedding.weight.cpu(), data_id, aux_info)
        else:
            metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info)

        with open(f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.json", "w") as f:
            json.dump(metric_dict, f, indent=4)
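
Based on the argparse options above, the script is invoked per dataset/model pair, e.g. python run_exp.py --data_id lattice --model_id H_MLP from src/ (the seed defaults to 77). Each run writes a checkpoint, a training-results JSON, and a crystal-metric JSON into ../results/. A small sketch (not part of the commit) of how the Exp2 metric files could be collected afterwards, assuming the file-name pattern used above:

    import json

    # Hypothetical aggregation of the per-run metric files written by Exp2.
    seed, data_id, model_id, train_ratio = 77, "lattice", "H_MLP", 0.8
    data_size_list = [100, 200, 500, 1000, 2000, 5000, 10000]

    metric_by_size = {}
    for data_size in data_size_list:
        path = f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.json"
        with open(path) as f:
            metric_by_size[data_size] = json.load(f)["metric"]  # 'metric' key written by crystal_metric

    print(metric_by_size)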

src/utils/crystal_metric.py

Lines changed: 5 additions & 5 deletions

@@ -73,7 +73,7 @@ def side_length_deviation(a, b, c, d):
 
     metric_dict = {
         'metric': np.mean(deviation_arr),
-        'variances': variances,
+        'variances': variances.tolist(),
     }
 
     return metric_dict

@@ -93,7 +93,7 @@ def greater_metric(reps, aux_info):
 
     metric_dict = {
         'metric': np.std(diff_arr) / np.mean(diff_arr),
-        'variances': variances,
+        'variances': variances.tolist(),
     }
     return metric_dict
 
@@ -135,7 +135,7 @@ def family_tree_metric(reps, aux_info):
 
     metric_dict = {
         'metric': 1 - np.mean([collinearity for collinearity in collinearity_by_generation.values() if not np.isnan(collinearity)]),
-        'variances': variances,
+        'variances': variances.tolist(),
     }
     return metric_dict
 
@@ -160,7 +160,7 @@ def equivalence_metric(reps, aux_info):
     print(np.mean(diff_arr) , np.mean(cross_diff_arr))
     metric_dict = {
         'metric': np.mean(diff_arr) / np.mean(cross_diff_arr),
-        'variances': variances,
+        'variances': variances.tolist(),
     }
     return metric_dict
 
@@ -187,6 +187,6 @@ def circle_metric(reps, aux_info):
 
     metric_dict = {
         'metric': circularity_score,
-        'variances': variances,
+        'variances': variances.tolist(),
     }
     return metric_dict
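
The variances.tolist() change is what lets run_exp.py dump these dictionaries with json.dump: the standard-library json module serializes Python lists but rejects NumPy arrays. A minimal illustration:

    import json
    import numpy as np

    variances = np.array([0.1, 0.2, 0.3])

    try:
        json.dumps({'variances': variances})            # numpy array -> TypeError
    except TypeError as err:
        print("not JSON serializable:", err)

    print(json.dumps({'variances': variances.tolist()}))  # works: {"variances": [0.1, 0.2, 0.3]}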

src/utils/dataset.py

Lines changed: 9 additions & 2 deletions

@@ -1,6 +1,10 @@
 import numpy as np
 import torch
 
+import sys
+sys.path.append("..")
+from src.utils.FamilyTreeGenerator import GenerateFamilyTree
+
 def parallelogram_dataset(p, dim, num, seed=0, device='cpu'):
 
     torch.manual_seed(seed)

@@ -34,7 +38,7 @@ def parallelogram_dataset(p, dim, num, seed=0, device='cpu'):
     return dataset
 
 
-def modular_addition_dataset(p, seed=0, device='cpu'):
+def modular_addition_dataset(p, num, seed=0, device='cpu'):
 
     torch.manual_seed(seed)
     np.random.seed(seed)

@@ -43,8 +47,11 @@ def modular_addition_dataset(p, seed=0, device='cpu'):
     y = np.arange(p)
     XX, YY = np.meshgrid(x, y)
     data_id = np.transpose([XX.reshape(-1,), YY.reshape(-1,)])
+
+    data_id = np.random.choice(len(data_id), size=num, replace=True)
     labels = (data_id[:,0] + data_id[:,1]) % p
     labels = torch.tensor(labels, dtype=torch.long)
+
 
     vocab_size = p
 
@@ -281,7 +288,6 @@ def mod_classification_dataset(p, num, seed=0, device='cpu'):
 
     return dataset
 
-from FamilyTreeGenerator import GenerateFamilyTree
 def family_tree_dataset(p, num, seed=0, device='cpu'):
 
     torch.manual_seed(seed)

@@ -311,5 +317,6 @@ def family_tree_dataset(p, num, seed=0, device='cpu'):
     dataset['data_id'] = data_id
     dataset['label'] = labels
     dataset['vocab_size'] = vocab_size
+    dataset['dict_level'] = ret_dic['dict_level']
 
     return dataset
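
For context, modular_addition_dataset now takes a num argument and draws that many (x, y) pairs, with replacement, from the full p × p grid. A standalone sketch of that sampling pattern (illustrative variable names, not the commit's exact code; the sampled indices are used here to select rows from the pair array):

    import numpy as np

    p, num, seed = 7, 20, 0
    np.random.seed(seed)

    x = np.arange(p)
    y = np.arange(p)
    XX, YY = np.meshgrid(x, y)
    pairs = np.transpose([XX.reshape(-1,), YY.reshape(-1,)])     # all p*p (x, y) pairs

    idx = np.random.choice(len(pairs), size=num, replace=True)   # sample num pair indices
    sample = pairs[idx]
    labels = (sample[:, 0] + sample[:, 1]) % p                   # modular-addition labels

    print(sample.shape, labels.shape)   # (20, 2) (20,)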
