from src.utils.crystal_metric import crystal_metric
import json

import os
from datetime import datetime

# NOTE(review): this chunk was recovered from a corrupted diff view
# (interleaved old/new hunk lines, fused line numbers, elided hunks, trailing
# webpage text).  The code below resolves the diff to its post-commit state.
# Spots where the view elided lines are reconstructed from the surrounding
# pattern and marked "TODO(recovered)" — confirm each against the repository.
# argparse / numpy (np) / torch / tqdm and the project helpers
# train_single_model / visualize_embedding are assumed to be imported in the
# file lines above this chunk (old lines 1-14, not visible here).

data_id_choices = ["lattice", "greater", "family_tree", "equivalence", "circle", "permutation"]
model_id_choices = ["H_MLP", "standard_MLP", "H_transformer", "standard_transformer"]

if __name__ == '__main__':
    # TODO(recovered): parser construction and the --seed argument were elided
    # from the recovered view; reconstructed minimally here.
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, required=True, help='Random seed')
    parser.add_argument('--data_id', type=str, required=True, choices=data_id_choices, help='Data ID')
    parser.add_argument('--model_id', type=str, required=True, choices=model_id_choices, help='Model ID')

    args = parser.parse_args()
    seed = args.seed
    data_id = args.data_id
    model_id = args.model_id

    ## ------------------------ CONFIG -------------------------- ##

    data_size = 1000
    train_ratio = 0.8
    embd_dim = 16

    lr = 0.002
    weight_decay = 0.01

    n_exp = embd_dim

    param_dict = {
        'seed': seed,
        'data_id': data_id,      # TODO(recovered): entry elided in view; matches the later param_dicts
        'data_size': data_size,  # TODO(recovered): entry elided in view; matches the later param_dicts
        'train_ratio': train_ratio,
        'model_id': model_id,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'embd_dim': embd_dim,
        'n_exp': n_exp,
        'lr': lr,
        'weight_decay': weight_decay,
    }

    results_root = "../results_test"

    # One timestamped directory per run so repeated runs never clobber results.
    current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    results_root = f"{results_root}/{current_datetime}"
    # makedirs(exist_ok=True) instead of mkdir: also creates the parent
    # "../results_test" on a fresh checkout, and tolerates a same-second rerun.
    os.makedirs(results_root, exist_ok=True)

    # torch.device is not JSON serializable, so drop it from the saved config.
    param_dict_json = {k: v for k, v in param_dict.items() if k != 'device'}

    with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_config.json", "w") as f:
        json.dump(param_dict_json, f, indent=4)

    # Per-dataset auxiliary info consumed by crystal_metric.
    aux_info = {}
    if data_id == "lattice":
        aux_info["lattice_size"] = 5
    # TODO(recovered): the branches for the remaining data_ids ("greater",
    # "family_tree", "equivalence", "circle", "permutation") were elided from
    # the recovered view — restore them, otherwise those CLI choices hit the
    # ValueError below.
    else:
        raise ValueError(f"Unknown data_id: {data_id}")

    # Train the model
    print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}")
    ret_dic = train_single_model(param_dict)

    ## Exp1: Visualize Embeddings
    print(f"Experiment 1: Visualize Embeddings")
    model = ret_dic['model']
    dataset = ret_dic['dataset']
    torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt")

    # nn.Embedding exposes .weight; a raw parameter embedding exposes .data.
    if hasattr(model.embedding, 'weight'):
        visualize_embedding(model.embedding.weight.cpu(), title=f"{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}", save_path=f"{results_root}/emb_{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.png", dict_level=dataset['dict_level'] if 'dict_level' in dataset else None, color_dict=False if data_id == "permutation" else True, adjust_overlapping_text=False)
    else:
        visualize_embedding(model.embedding.data.cpu(), title=f"{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}", save_path=f"{results_root}/emb_{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.png", dict_level=dataset['dict_level'] if 'dict_level' in dataset else None, color_dict=False if data_id == "permutation" else True, adjust_overlapping_text=False)

    # ## Exp2: Metric vs Overall Dataset Size (fixed train-test split)
    # TODO(recovered): this experiment was already commented out in the
    # original; its head (print, data_size sweep, loop header, first
    # param_dict entries) was elided from the recovered view.
    # param_dict = {
    #     'train_ratio': train_ratio,
    #     'model_id': model_id,
    #     'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    #     'embd_dim': embd_dim,
    #     'n_exp': n_exp,
    #     'lr': lr,
    #     'weight_decay': weight_decay
    # }
    #
    # print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}")
    # ret_dic = train_single_model(param_dict)
    # model = ret_dic['model']
    # dataset = ret_dic['dataset']
    #
    # torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt")
    # with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f:
    #     json.dump(ret_dic["results"], f, indent=4)
    #
    # if data_id == "family_tree":
    #     aux_info["dict_level"] = dataset['dict_level']
    # if hasattr(model.embedding, 'weight'):
    #     metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info)
    # else:
    #     metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info)
    #
    # with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.json", "w") as f:
    #     json.dump(metric_dict, f, indent=4)

    ## Exp3: Metric vs Train Fraction (fixed dataset size)
    # TODO(recovered): the print, sweep values and loop header were elided
    # from the recovered view; the 0.1..0.9 sweep below is a plausible
    # reconstruction — confirm against the repository.
    print(f"Experiment 3: Metric vs Train Fraction")
    for train_ratio in [round(0.1 * i, 1) for i in range(1, 10)]:
        data_size = 1000
        param_dict = {
            'seed': seed,
            'data_id': data_id,
            'data_size': data_size,
            'train_ratio': train_ratio,
            'model_id': model_id,
            'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            'embd_dim': embd_dim,
            'n_exp': n_exp,
            'lr': lr,
            'weight_decay': weight_decay,
        }
        print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}")
        ret_dic = train_single_model(param_dict)
        model = ret_dic['model']
        dataset = ret_dic['dataset']

        torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt")
        with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f:
            json.dump(ret_dic["results"], f, indent=4)

        # family_tree is the one dataset whose metric needs the dict_level map.
        if data_id == "family_tree":
            aux_info["dict_level"] = dataset['dict_level']

        if hasattr(model.embedding, 'weight'):
            metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info)
        else:
            metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info)

        with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_metric.json", "w") as f:
            json.dump(metric_dict, f, indent=4)

    ## Exp4: Grokking plot: Run with different seeds
    # TODO(recovered): the print, seed sweep and loop header were elided from
    # the recovered view; the range below is a plausible reconstruction —
    # confirm against the repository.  Note this rebinds `seed`, which Exp5
    # then reuses at its last swept value.
    print(f"Experiment 4: Grokking runs over seeds")
    for seed in range(10):
        data_size = 1000
        train_ratio = 0.8
        param_dict = {
            'seed': seed,
            'data_id': data_id,
            'data_size': data_size,
            'train_ratio': train_ratio,
            'model_id': model_id,
            'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            'embd_dim': embd_dim,
            'n_exp': n_exp,
            'lr': lr,
            'weight_decay': weight_decay,
        }
        print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}")
        ret_dic = train_single_model(param_dict)
        model = ret_dic['model']
        dataset = ret_dic['dataset']
        torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt")
        with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f:
            json.dump(ret_dic["results"], f, indent=4)

        if data_id == "family_tree":
            aux_info["dict_level"] = dataset['dict_level']

        if hasattr(model.embedding, 'weight'):
            metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info)
        else:
            metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info)

        with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.json", "w") as f:
            json.dump(metric_dict, f, indent=4)

    # Exp5: N Exponent value plot: run with different n values, plot test
    # accuracy vs. n and explained variance vs. n.
    print(f"Experiment 5: Train with different exponent values")
    n_list = np.arange(1, 17, dtype=int)

    for i in tqdm(range(len(n_list))):
        n_exp = n_list[i]
        data_size = 1000
        train_ratio = 0.8

        # NOTE(review): unlike Exps 1-4 this param_dict omits 'lr' and
        # 'weight_decay' (as in the recovered source) — confirm that
        # train_single_model's defaults are the intended values here.
        param_dict = {
            'seed': seed,
            'data_id': data_id,
            'data_size': data_size,
            'train_ratio': train_ratio,
            'model_id': model_id,
            'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            'embd_dim': embd_dim,
            'n_exp': n_exp
        }
        print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}")

        ret_dic = train_single_model(param_dict)
        model = ret_dic['model']
        dataset = ret_dic['dataset']
        torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt")
        with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f:
            json.dump(ret_dic["results"], f, indent=4)

        if data_id == "family_tree":
            aux_info["dict_level"] = dataset['dict_level']

        # TODO(recovered): this hasattr stanza was partially elided from the
        # recovered view; reconstructed from the identical Exp3/Exp4 pattern.
        if hasattr(model.embedding, 'weight'):
            metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info)
        else:
            metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info)

        with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.json", "w") as f:
            json.dump(metric_dict, f, indent=4)