1+ import torch
2+ import torch .nn as nn
3+ import torch .optim as optim
4+ import numpy as np
5+ import random
6+
7+ from tqdm import tqdm
8+
9+ import sys
10+ sys .path .append (".." )
11+
12+ import argparse
13+ from src .utils .driver import train_single_model
14+ from src .utils .visualization import visualize_embedding
15+ from src .utils .crystal_metric import crystal_metric
16+ import json
17+
data_id_choices = ["lattice", "greater", "family_tree", "equivalence", "circle"]
model_id_choices = ["H_MLP", "standard_MLP", "H_transformer", "standard_transformer"]

# Values used by every experiment except the one that sweeps that value.
DEFAULT_DATA_SIZE = 1000
DEFAULT_TRAIN_RATIO = 0.8
EMBD_DIM = 16

# Fixed auxiliary arguments passed to crystal_metric for each dataset.
# "family_tree" is absent on purpose: its aux info comes from the trained
# run's dataset (see build_aux_info).
STATIC_AUX_INFO = {
    "lattice": {"lattice_size": 5},
    "greater": {"p": 30},
    "equivalence": {"mod": 5},
    "circle": {"p": 59},
}


def run_tag(seed, data_id, model_id, data_size, train_ratio):
    """Return the identifier shared by every result file of one training run."""
    # NOTE(review): the blanks before each underscore and at the end look like
    # formatting residue; they are reproduced as found so that file names stay
    # unchanged -- confirm against the files already in ../results.
    return f"{seed} _{data_id} _{model_id} _{data_size} _{train_ratio} "


def build_param_dict(seed, data_id, model_id, data_size, train_ratio, device):
    """Assemble the configuration dict consumed by ``train_single_model``."""
    return {
        'seed': seed,
        'data_id': data_id,
        'data_size': data_size,
        'train_ratio': train_ratio,
        'model_id': model_id,
        'device': device,
        'embd_dim': EMBD_DIM,
    }


def build_aux_info(data_id, dataset):
    """Return a fresh aux-info dict for ``crystal_metric``.

    Args:
        data_id: One of ``data_id_choices``.
        dataset: The ``'dataset'`` entry returned by training; only read
            (key ``'dict_level'``) when ``data_id`` is ``"family_tree"``.

    Raises:
        ValueError: If ``data_id`` is not a known dataset.
    """
    if data_id == "family_tree":
        return {"dict_level": dataset['dict_level']}
    if data_id not in STATIC_AUX_INFO:
        raise ValueError(f"Unknown data_id: {data_id} ")
    # Copy so that callers can never mutate the shared table.
    return dict(STATIC_AUX_INFO[data_id])


def embedding_matrix(model):
    """Return the model's embedding values moved to the CPU.

    ``model.embedding`` is either an object exposing ``.weight`` or an object
    exposing ``.data`` directly; both shapes are supported.
    """
    embedding = model.embedding
    if hasattr(embedding, 'weight'):
        return embedding.weight.cpu()
    return embedding.data.cpu()


def train_and_score(param_dict):
    """Train one configuration and write its checkpoint, training log and metric.

    Writes three files under ``../results``: the model state dict (``.pt``),
    the ``'results'`` entry of the training output (``_train_results.json``)
    and the output of ``crystal_metric`` (``.json``).
    """
    seed = param_dict['seed']
    data_id = param_dict['data_id']
    model_id = param_dict['model_id']
    data_size = param_dict['data_size']
    train_ratio = param_dict['train_ratio']
    tag = run_tag(seed, data_id, model_id, data_size, train_ratio)

    print(f"Training model with seed {seed} , data_id {data_id} , model_id {model_id} with train_ratio {train_ratio} and data_size {data_size} ")
    ret_dic = train_single_model(param_dict)
    model = ret_dic['model']
    dataset = ret_dic['dataset']

    torch.save(model.state_dict(), f"../results/{tag}.pt")
    with open(f"../results/{tag}_train_results.json", "w") as f:
        json.dump(ret_dic["results"], f, indent=4)

    aux_info = build_aux_info(data_id, dataset)
    metric_dict = crystal_metric(embedding_matrix(model), data_id, aux_info)

    with open(f"../results/{tag}.json", "w") as f:
        json.dump(metric_dict, f, indent=4)


def main():
    """Parse the command line and run experiments 1-4 in order."""
    import os  # local import: only needed to prepare the output directory

    parser = argparse.ArgumentParser(description='Experiment')
    parser.add_argument('--seed', type=int, default=77, help='random seed')
    parser.add_argument('--data_id', type=str, required=True, choices=data_id_choices, help='Data ID')
    parser.add_argument('--model_id', type=str, required=True, choices=model_id_choices, help='Model ID')
    args = parser.parse_args()
    seed = args.seed
    data_id = args.data_id
    model_id = args.model_id

    # The device does not change between runs, so resolve it once.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create the output directory up front so that a missing folder cannot
    # throw away a finished training run at the first save.
    os.makedirs("../results", exist_ok=True)

    # Train the model
    param_dict = build_param_dict(
        seed, data_id, model_id, DEFAULT_DATA_SIZE, DEFAULT_TRAIN_RATIO, device
    )
    print(f"Training model with seed {seed} , data_id {data_id} , model_id {model_id} ")
    ret_dic = train_single_model(param_dict)

    ## Exp1: Visualize Embeddings
    print("Experiment 1: Visualize Embeddings")
    model = ret_dic['model']
    dataset = ret_dic['dataset']
    tag = run_tag(seed, data_id, model_id, DEFAULT_DATA_SIZE, DEFAULT_TRAIN_RATIO)
    # This checkpoint shares its name with the default-size / default-ratio
    # runs of Exp2 and Exp3, which overwrite it later.
    torch.save(model.state_dict(), f"../results/{tag}.pt")
    visualize_embedding(
        embedding_matrix(model),
        title=tag,
        save_path=f"../results/emb_{tag}.png",
        dict_level=dataset['dict_level'] if 'dict_level' in dataset else None,
    )

    ## Exp2: Metric vs Overall Dataset Size (fixed train-test split)
    print("Experiment 2: Metric vs Overall Dataset Size (fixed train-test split)")
    data_size_list = [100, 200, 500, 1000, 2000, 5000, 10000]
    for data_size in tqdm(data_size_list):
        train_and_score(
            build_param_dict(seed, data_id, model_id, data_size, DEFAULT_TRAIN_RATIO, device)
        )

    ## Exp3: Metric vs Train Fraction (fixed dataset size)
    print("Experiment 3: Metric vs Train Fraction (fixed dataset size)")
    # .tolist() yields plain Python floats, matching the float used by the
    # other experiments instead of handing NumPy scalars to the trainer.
    train_ratio_list = (np.arange(1, 10) / 10).tolist()
    for train_ratio in tqdm(train_ratio_list):
        train_and_score(
            build_param_dict(seed, data_id, model_id, DEFAULT_DATA_SIZE, train_ratio, device)
        )

    ## Exp4: Grokking plot: Run with different seeds
    print("Experiment 4: Train with different seeds")
    # Plain Python ints, like the CLI seed used above; NumPy integer scalars
    # are not accepted by every seeding function.
    seed_list = np.linspace(0, 1000, 20, dtype=int).tolist()
    for run_seed in tqdm(seed_list):
        train_and_score(
            build_param_dict(
                run_seed, data_id, model_id, DEFAULT_DATA_SIZE, DEFAULT_TRAIN_RATIO, device
            )
        )


if __name__ == '__main__':
    main()
0 commit comments