@@ -126,9 +126,7 @@ def __getitem__(self, idx):
126126 if torch .is_tensor (idx ):
127127 idx = idx .tolist ()
128128
129- return torch .tensor (self ._images [idx ]), torch .tensor (
130- self ._labels [idx ]
131- )
129+ return torch .tensor (self ._images [idx ]), torch .tensor (self ._labels [idx ])
132130
133131
134132def _read32 (bytestream ):
@@ -225,19 +223,19 @@ class DataSets(object):
225223
226224 SOURCE_URL = "http://www.cs.toronto.edu/~jmartens/"
227225 TRAIN_IMAGES = "newfaces_rot_single.mat"
228-
226+
229227 local_file = maybe_download (SOURCE_URL , TRAIN_IMAGES , train_dir )
230228 print (f"Data read from { local_file } " )
231-
232- numpy_file = os .path .dirname (local_file ) + ' /faces.npy'
229+
230+ numpy_file = os .path .dirname (local_file ) + " /faces.npy"
233231 if os .path .exists (numpy_file ):
234232 images_ = np .load (numpy_file )
235233 else :
236234 import mat4py
237235
238236 images_ = mat4py .loadmat (local_file )
239237 images_ = np .asarray (images_ ["newfaces_single" ])
240-
238+
241239 images_ = np .transpose (images_ )
242240 np .save (numpy_file , images_ )
243241 print (f"Data saved to { numpy_file } " )
@@ -276,10 +274,9 @@ class DataSets(object):
276274if __name__ == "__main__" :
277275
278276 argparser = argparse .ArgumentParser ()
279- argparser .add_argument (' --exp' , type = str , help = ' which dataset' , default = ' FACES' )
277+ argparser .add_argument (" --exp" , type = str , help = " which dataset" , default = " FACES" )
280278 args = argparser .parse_args ()
281279
282-
283280 seed = 13
284281 torch .manual_seed (seed )
285282 torch .backends .cudnn .benchmark = False
@@ -291,7 +288,7 @@ class DataSets(object):
291288 print ("device" , device )
292289
293290 ## Hyperparams
294- if args .exp == ' FACES' :
291+ if args .exp == " FACES" :
295292
296293 batch_size = 100
297294 epochs = 5
@@ -304,8 +301,8 @@ class DataSets(object):
304301 damping = 1.0
305302
306303 dataset = read_data_sets ("FACES" , "../data/" , if_autoencoder = True )
307-
308- if args .exp == ' MNIST' :
304+
305+ if args .exp == " MNIST" :
309306 batch_size = 100
310307 epochs = 10
311308 eta_adam = 1e-4
@@ -321,19 +318,23 @@ class DataSets(object):
321318 ## Dataset
322319 train_dataset = dataset .train
323320 test_dataset = dataset .test
324- if args .exp == 'FACES' :
325- likelihood = FISH_LIKELIHOODS ['fixedgaussian' ](sigma = 1.0 , device = device )
321+ if args .exp == "FACES" :
322+ likelihood = FISH_LIKELIHOODS ["fixedgaussian" ](sigma = 1.0 , device = device )
323+
326324 def mse (model , data ):
327325 data_x , data_y = data
328326 pred_y = model .forward (data_x )
329- return torch .mean (torch .square (pred_y - data_y ))
330- if args .exp == 'MNIST' :
331- likelihood = FISH_LIKELIHOODS ['bernoulli' ](device = device )
327+ return torch .mean (torch .square (pred_y - data_y ))
328+
329+ if args .exp == "MNIST" :
330+ likelihood = FISH_LIKELIHOODS ["bernoulli" ](device = device )
331+
332332 def mse (model , data ):
333333 data_x , data_y = data
334334 pred_y = model .forward (data_x )
335335 pred_y = torch .sigmoid (pred_y )
336- return torch .mean (torch .square (pred_y - data_y ))
336+ return torch .mean (torch .square (pred_y - data_y ))
337+
337338 def nll (model , data ):
338339 data_x , data_y = data
339340 pred_y = model .forward (data_x )
@@ -344,7 +345,6 @@ def draw(model, data):
344345 pred_y = model .forward (data_x )
345346 return (data_x , likelihood .draw (pred_y ))
346347
347-
348348 train_loader = torch .utils .data .DataLoader (
349349 train_dataset , batch_size = batch_size , shuffle = True
350350 )
@@ -360,43 +360,42 @@ def draw(model, data):
360360 test_dataset , batch_size = 1000 , shuffle = False
361361 )
362362
363-
364- if args .exp == 'FACES' :
363+ if args .exp == "FACES" :
365364 model = nn .Sequential (
366- nn .Linear (625 , 2000 ),
367- nn .ReLU (),
368- nn .Linear (2000 , 1000 ),
369- nn .ReLU (),
370- nn .Linear (1000 , 500 ),
371- nn .ReLU (),
372- nn .Linear (500 , 30 ),
373- nn .Linear (30 , 500 ),
374- nn .ReLU (),
375- nn .Linear (500 , 1000 ),
376- nn .ReLU (),
377- nn .Linear (1000 , 2000 ),
378- nn .ReLU (),
379- nn .Linear (2000 , 625 ),
365+ nn .Linear (625 , 2000 ),
366+ nn .ReLU (),
367+ nn .Linear (2000 , 1000 ),
368+ nn .ReLU (),
369+ nn .Linear (1000 , 500 ),
370+ nn .ReLU (),
371+ nn .Linear (500 , 30 ),
372+ nn .Linear (30 , 500 ),
373+ nn .ReLU (),
374+ nn .Linear (500 , 1000 ),
375+ nn .ReLU (),
376+ nn .Linear (1000 , 2000 ),
377+ nn .ReLU (),
378+ nn .Linear (2000 , 625 ),
380379 ).to (device )
381-
382- if args .exp == ' MNIST' :
380+
381+ if args .exp == " MNIST" :
383382 model = nn .Sequential (
384- nn .Linear (784 , 1000 , dtype = torch .float32 ),
385- nn .ReLU (),
386- nn .Linear (1000 , 500 , dtype = torch .float32 ),
387- nn .ReLU (),
388- nn .Linear (500 , 250 , dtype = torch .float32 ),
389- nn .ReLU (),
390- nn .Linear (250 , 30 , dtype = torch .float32 ),
391- nn .Linear (30 , 250 , dtype = torch .float32 ),
392- nn .ReLU (),
393- nn .Linear (250 , 500 , dtype = torch .float32 ),
394- nn .ReLU (),
395- nn .Linear (500 , 1000 , dtype = torch .float32 ),
396- nn .ReLU (),
397- nn .Linear (1000 , 784 , dtype = torch .float32 ),
383+ nn .Linear (784 , 1000 , dtype = torch .float32 ),
384+ nn .ReLU (),
385+ nn .Linear (1000 , 500 , dtype = torch .float32 ),
386+ nn .ReLU (),
387+ nn .Linear (500 , 250 , dtype = torch .float32 ),
388+ nn .ReLU (),
389+ nn .Linear (250 , 30 , dtype = torch .float32 ),
390+ nn .Linear (30 , 250 , dtype = torch .float32 ),
391+ nn .ReLU (),
392+ nn .Linear (250 , 500 , dtype = torch .float32 ),
393+ nn .ReLU (),
394+ nn .Linear (500 , 1000 , dtype = torch .float32 ),
395+ nn .ReLU (),
396+ nn .Linear (1000 , 784 , dtype = torch .float32 ),
398397 ).to (device )
399-
398+
400399 model_adam = copy .deepcopy (model )
401400
402401 print ("lr fl={}, lr sgd={}, lr aux={}" .format (eta_fl , eta_sgd , aux_eta ))
@@ -417,10 +416,15 @@ def draw(model, data):
417416 damping = damping ,
418417 pre_aux_training = 0 ,
419418 sgd_lr = eta_sgd ,
420- initialization = ' normal' ,
419+ initialization = " normal" ,
421420 device = device ,
422421 )
423422
423+ print (opt .__dict__ ["fish_lr" ])
424+ print (opt .__dict__ ["beta" ])
425+ print (opt .__dict__ ["aux_lr" ])
426+ print (opt .__dict__ ["damping" ])
427+ print (opt .__dict__ ["sgd_lr" ])
424428
425429 FL_time = []
426430 LOSS = []
@@ -429,7 +433,7 @@ def draw(model, data):
429433 iteration = 0
430434 for e in range (1 , epochs + 1 ):
431435 print ("######## EPOCH" , e )
432- for n , (batch_data , batch_labels ) in enumerate (train_loader ):
436+ for n , (batch_data , batch_labels ) in enumerate (train_loader , start = 1 ):
433437 iteration += 1
434438 batch_data , batch_labels = batch_data .to (device ), batch_labels .to (device )
435439 opt .zero_grad ()
@@ -440,16 +444,18 @@ def draw(model, data):
440444 if n % 50 == 0 :
441445 FL_time .append (time .time () - st )
442446 LOSS .append (loss .detach ().cpu ().numpy ())
443-
447+
444448 test_batch_data , test_batch_labels = next (iter (test_loader ))
445- test_batch_data , test_batch_labels = test_batch_data .to (device ), test_batch_labels .to (device )
449+ test_batch_data , test_batch_labels = test_batch_data .to (
450+ device
451+ ), test_batch_labels .to (device )
446452 test_loss = mse (opt .model , (test_batch_data , test_batch_labels ))
447-
453+
448454 TEST_LOSS .append (test_loss .detach ().cpu ().numpy ())
449455
450456 print (n , LOSS [- 1 ], TEST_LOSS [- 1 ])
451-
452- fig , axs = plt .subplots (1 ,2 , figsize = (10 ,5 ))
457+
458+ fig , axs = plt .subplots (1 , 2 , figsize = (10 , 5 ))
453459 axs [0 ].plot (FL_time , LOSS , label = "Fishleg" ) # color=colors_group[i])
454460 axs [1 ].plot (
455461 FL_time , TEST_LOSS , label = "Fishleg"
@@ -478,21 +484,23 @@ def draw(model, data):
478484 opt .step ()
479485
480486 if n % 50 == 0 :
481- FL_time .append (time .time ()- st )
487+ FL_time .append (time .time () - st )
482488 LOSS .append (loss .detach ().cpu ().numpy ())
483489 test_batch_data , test_batch_labels = next (iter (test_loader_adam ))
484- test_batch_data , test_batch_labels = test_batch_data .to (device ), test_batch_labels .to (device )
490+ test_batch_data , test_batch_labels = test_batch_data .to (
491+ device
492+ ), test_batch_labels .to (device )
485493 test_loss = mse (model_adam , (test_batch_data , test_batch_labels ))
486494 TEST_LOSS .append (test_loss .detach ().cpu ().numpy ())
487495
488496 print (n , LOSS [- 1 ], TEST_LOSS [- 1 ])
489497
490498 axs [0 ].plot (FL_time , LOSS , label = "Adam" )
491499 axs [1 ].plot (FL_time , TEST_LOSS , label = "Adam" )
492-
500+
493501 axs [0 ].legend ()
494502 axs [1 ].legend ()
495503
496- axs [0 ].set_title (' Training Loss' )
497- axs [1 ].set_title (' Test MSE' )
504+ axs [0 ].set_title (" Training Loss" )
505+ axs [1 ].set_title (" Test MSE" )
498506 fig .savefig ("result/result.png" , dpi = 300 )