77
88from tqdm import tqdm
99
10- from itertools import combinations
11- from sklearn .decomposition import PCA
12-
1310class MLP (nn .Module ):
1411 def __init__ (self , shp , vocab_size , embd_dim , input_token = 2 , init_scale = 1. , unembd = False , weight_tied = False , seed = 0 ):
1512 super (MLP , self ).__init__ ()
@@ -203,60 +200,3 @@ def train(self, param_dict: dict):
203200
204201 if (epoch + 1 ) % 50 == 0 :
205202 print (f"Epoch { epoch + 1 } /{ num_epochs } , Loss: { total_loss / len (dataloader ):.4f} " )
206-
207-
def eval(self, grid_size: int = 5, n_components: int = 10) -> dict:
    """Measure how lattice-like the learned token embeddings are.

    Token ``grid_size * i + j`` is treated as grid point ``(i, j)``. For every
    unordered triple of grid points ``(a, b, c)`` the fourth corner
    ``d = b + c - a`` is formed. Triples whose ``d`` falls outside the grid,
    or whose three points share a row or a column, are skipped. The remaining
    quadruples are scored by how unequal their opposite sides are in
    embedding space (0 means opposite sides have equal length).

    Args:
        grid_size: Side length of the square grid of tokens. The embedding
            table must hold at least ``grid_size ** 2`` rows.
        n_components: Number of principal components to keep for the
            variance spectrum.

    Returns:
        A dict with:
            ``'parallelogram_quality'``: mean side-length deviation over all
            scored quadruples.
            ``'variances'``: ``explained_variance_ratio_`` of the PCA fit
            described in the body.

    NOTE(review): this appears to be a method of ``MLP(nn.Module)``, so it
    shadows ``torch.nn.Module.eval``; calling ``model.eval()`` returns these
    metrics instead of switching the module to evaluation mode. Renaming is
    left to the caller-side owners because it would change the interface.
    """

    def side_length_deviation(a, b, c, d):
        """Return the normalised mismatch between opposite sides a-b/c-d and a-c/b-d."""
        length_ab = np.linalg.norm(b - a)
        length_cd = np.linalg.norm(d - c)
        length_ac = np.linalg.norm(c - a)
        length_bd = np.linalg.norm(b - d)
        length_bc = np.linalg.norm(c - b)
        length_ad = np.linalg.norm(d - a)

        mismatch = np.sqrt((length_ab - length_cd) ** 2 + (length_ac - length_bd) ** 2)
        # NOTE(review): with d = b + c - a, b-c and a-d are the diagonals of the
        # quadrilateral, so this scale mixes two sides with two diagonals.
        # Kept exactly as written so reported numbers do not change - confirm intent.
        scale = np.sqrt((length_ab ** 2 + length_bc ** 2 + length_cd ** 2 + length_ad ** 2) / 2)
        return mismatch / scale

    # Convert the embedding table once instead of four times per quadruple.
    emb = self.embedding.weight.detach().cpu().numpy()

    points = [(i, j) for i in range(grid_size) for j in range(grid_size)]
    deviation_arr = []

    for a, b, c in combinations(points, 3):
        d = (c[0] + b[0] - a[0], c[1] + b[1] - a[1])
        # The completing corner must itself be a grid point.
        if not (0 <= d[0] < grid_size and 0 <= d[1] < grid_size):
            continue
        # Skip triples lying on a single row or a single column.
        if a[0] == b[0] == c[0] or a[1] == b[1] == c[1]:
            continue

        # Map grid coordinates to rows of the embedding table.
        idx_a, idx_b, idx_c, idx_d = (grid_size * p[0] + p[1] for p in (a, b, c, d))
        deviation_arr.append(
            side_length_deviation(emb[idx_a], emb[idx_b], emb[idx_c], emb[idx_d])
        )

    pca = PCA(n_components=n_components)
    emb_pca = pca.fit_transform(emb)
    # NOTE(review): the estimator is refit on the already-projected data, so the
    # ratios below are presumably relative to the variance of the retained
    # components rather than of the full embedding. Behaviour preserved from the
    # original (which discarded the transform result) - confirm this is intended.
    pca.fit(emb_pca)
    variances = pca.explained_variance_ratio_

    return {
        'parallelogram_quality': np.mean(deviation_arr),
        'variances': variances,
    }
0 commit comments