77
88from tqdm import tqdm
99
10+ from itertools import combinations
11+ from sklearn .decomposition import PCA
12+
1013class MLP (nn .Module ):
1114 def __init__ (self , shp , vocab_size , embd_dim , input_token = 2 , init_scale = 1. , unembd = False , weight_tied = False , seed = 0 ):
1215 super (MLP , self ).__init__ ()
@@ -169,17 +172,21 @@ def forward(self, x):
169172 else :
170173 logits = self .fc (x [:, - 1 ]) # Only predict the last token
171174 return logits
175+
172176 def train (self , param_dict : dict ):
173177
174178 num_epochs = param_dict ['num_epochs' ]
175179 learning_rate = param_dict ['learning_rate' ]
176180 dataloader = param_dict ['dataloader' ]
181+ device = param_dict ['device' ]
177182 criterion = nn .CrossEntropyLoss ()
178183
179184 optimizer = optim .AdamW (self .parameters (), lr = learning_rate )
180185 for epoch in tqdm (range (num_epochs )):
181186 total_loss = 0
182187 for batch_inputs , batch_targets in dataloader :
188+ batch_inputs = batch_inputs .to (device )
189+ batch_targets = batch_targets .to (device )
183190 optimizer .zero_grad ()
184191 logits = self .forward (batch_inputs )
185192
@@ -189,4 +196,61 @@ def train(self, param_dict: dict):
189196 total_loss += loss .item ()
190197
191198 if (epoch + 1 ) % 50 == 0 :
192- print (f"Epoch { epoch + 1 } /{ num_epochs } , Loss: { total_loss / len (dataloader ):.4f} " )
199+ print (f"Epoch { epoch + 1 } /{ num_epochs } , Loss: { total_loss / len (dataloader ):.4f} " )
200+
201+
def eval(self, *, grid_size=5):
    """Score how parallelogram-like the learned token embeddings are.

    Token ids are read as cells of a ``grid_size`` x ``grid_size`` lattice,
    with cell ``(i, j)`` mapped to row ``grid_size * i + j`` of
    ``self.embedding.weight``. For every unordered triple of cells
    ``(a, b, c)`` the fourth corner ``d = b + c - a`` is formed; triples whose
    ``d`` leaves the lattice, or whose three cells share a row or a column,
    are skipped. The remaining quadruples are scored in embedding space and
    the scores are averaged.

    NOTE(review): the enclosing class derives from ``nn.Module`` (presumably
    ``torch.nn.Module``), so this method shadows ``Module.eval()``, which
    normally switches the module to evaluation mode. Confirm that callers do
    not rely on that behaviour.

    Args:
        grid_size: Side length of the lattice. The embedding table needs at
            least ``grid_size ** 2`` rows. Defaults to 5.

    Returns:
        dict with two entries:
            ``'parallelogram_quality'``: mean deviation over all retained
                quadruples (``nan`` when none are retained).
            ``'variances'``: ``explained_variance_ratio_`` of the PCA fitted
                last (see the note next to the PCA code).
    """

    def side_length_deviation(a, b, c, d):
        """Return the normalised side-length mismatch of the quadruple."""
        a, b, c, d = (np.asarray(v) for v in (a, b, c, d))

        length_ab = np.linalg.norm(b - a)
        length_cd = np.linalg.norm(d - c)
        length_ac = np.linalg.norm(c - a)
        length_bd = np.linalg.norm(b - d)
        length_bc = np.linalg.norm(c - b)
        length_ad = np.linalg.norm(d - a)

        # Numerator: mismatch between the two pairs of opposite sides
        # (ab vs cd, ac vs bd).
        mismatch = np.sqrt(
            (length_ab - length_cd) ** 2 + (length_ac - length_bd) ** 2
        )
        # NOTE(review): with d = b + c - a, the segments bc and ad are the
        # diagonals of the lattice parallelogram, yet they appear in this
        # normaliser next to the sides ab and cd. Kept as-is so the metric
        # stays comparable with earlier runs; confirm this is intended.
        scale = np.sqrt(
            (length_ab ** 2 + length_bc ** 2 + length_cd ** 2 + length_ad ** 2) / 2
        )
        return mismatch / scale

    # Convert the embedding table once instead of once per corner per triple.
    weights = self.embedding.weight.cpu().detach().numpy()

    points = [(i, j) for i in range(grid_size) for j in range(grid_size)]
    deviation_arr = []
    for a, b, c in combinations(points, 3):
        # Fourth corner that completes a -> b and a -> c into a parallelogram.
        d = (c[0] + b[0] - a[0], c[1] + b[1] - a[1])
        if not (0 <= d[0] < grid_size and 0 <= d[1] < grid_size):
            continue

        # Drop triples lying on one lattice row or one lattice column.
        # NOTE(review): triples that are collinear along a diagonal are not
        # filtered here; confirm whether that is intended.
        if a[0] == b[0] == c[0] or a[1] == b[1] == c[1]:
            continue

        rows = [grid_size * p[0] + p[1] for p in (a, b, c, d)]
        deviation_arr.append(side_length_deviation(*weights[rows]))

    # PCA cannot extract more components than min(n_samples, n_features);
    # clamp so small vocabularies or narrow embeddings do not raise.
    n_components = min(10, *weights.shape)
    pca = PCA(n_components=n_components)
    emb_pca = pca.fit_transform(weights)
    # NOTE(review): the PCA is refitted on the already-projected data, so the
    # reported ratios are relative to the variance kept by the first fit
    # rather than to the full embedding variance. Preserved from the original
    # behaviour; confirm this is what downstream consumers expect.
    pca.fit(emb_pca)
    variances = pca.explained_variance_ratio_

    result_dict = {
        'parallelogram_quality': np.mean(deviation_arr) if deviation_arr else float('nan'),
        'variances': variances,
    }

    return result_dict
0 commit comments