 from scipy import stats
 from torch.nn import functional as F
 from torch.autograd import Variable
+from radam import RAdam, PlainRAdam, AdamW

-batch_size = 32
-lr = 1e-5
+batch_size = 64
+lr = 1e-4
 momentum = 0.99
-num_epochs = 100
+num_epochs = 200
 percentage_train = 0.8
 percentage_val = 0.1
-lr_decay = 0.5
-step_size = 5
+lr_decay = 0.25
+step_size = 20
 # loss_weights = [1,1e0,1e21,1e15]
-loss_weights = [1, 0.1, 0.1, 0.1, 0.1]
-nphi = 1
+loss_weights = [1, 0.05, 0.05, 0.05, 0.05]
+nphi = 4
 plot_rate = 250
-output_rate = 500
-val_rate = 2000
+output_rate = 250
+val_rate = 1000
 datapath = '/scratch/gpfs/marcoam/ml_collisions/data/xgc1/ti272_JET_heat_load/'
 run_num = '00094/'
+lim = 80000

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

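With the StepLR scheduler configured later in this diff, these globals mean the learning rate is multiplied by lr_decay every step_size epochs. A minimal sketch of the resulting schedule (illustration only, not part of the commit):

# LR under StepLR(step_size=20, gamma=0.25) starting from lr = 1e-4
for epoch in [0, 19, 20, 40, 60]:
    print(epoch, lr * lr_decay ** (epoch // step_size))
# -> 1e-4, 1e-4, 2.5e-5, 6.25e-6, 1.5625e-6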
@@ -917,11 +919,11 @@ def load_data_hdf(iphi):
     i_df = hf_df['i_df'][iphi]

     ind1, ind2, ind3 = i_f.shape
-
-    f = np.zeros([ind2, 2, ind1, ind1])
-    df = np.zeros([ind2, 2, ind1, ind1])
+    # change lim back to ind2 to use the full data set
+    f = np.zeros([lim, 2, ind1, ind1])
+    df = np.zeros([lim, 2, ind1, ind1])

-    for n in range(ind2):
+    for n in range(lim):
         f[n, 0, :, :-1] = e_f[:, n, :]
         f[n, 1, :, :-1] = i_f[:, n, :]
         df[n, 0, :, :-1] = e_df[:, n, :]
@@ -943,7 +945,7 @@ def load_data_hdf(iphi):
     hf_stats = h5py.File(datapath + run_num + 'hdf_stats.h5', 'r')
     zvars = stats_variables(hf_stats)

-    for n in range(ind2):
+    for n in range(lim):
         f[n] = (f[n] - zvars.mean_f) / zvars.std_f
 #        df[n] = (df[n] - zvars.mean_df) / zvars.std_df
         df[n] = (df[n] - zvars.mean_fdf) / zvars.std_fdf
@@ -970,7 +972,7 @@ def load_data_hdf(iphi):
     zvars.std_df = torch.from_numpy(zvars.std_df).to(device).float()
     zvars.std_fdf = torch.from_numpy(zvars.std_fdf).to(device).float()

-    return f, df, ind2, zvars, cons
+    return f, df, lim, zvars, cons

 class stats_variables():

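The returned sample count (now lim rather than ind2) presumably feeds the train/validation split controlled by percentage_train and percentage_val. A rough sketch of that wiring under stated assumptions; the actual dataset/loader construction is outside this diff, and the names train_set and val_set are invented for illustration:

n_train = int(percentage_train * lim)   # 0.8 * 80000 = 64000
n_val = int(percentage_val * lim)       # 0.1 * 80000 = 8000
train_set = torch.utils.data.TensorDataset(
    torch.from_numpy(f[:n_train]), torch.from_numpy(df[:n_train]))
val_set = torch.utils.data.TensorDataset(
    torch.from_numpy(f[n_train:n_train + n_val]),
    torch.from_numpy(df[n_train:n_train + n_val]))
trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
valloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size)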
@@ -1204,15 +1206,16 @@ def train(trainloader,valloader,sp_flag,epoch,end,zvars,cons):
         data_unnorm = data * zvars.std_f[:nbatch] + zvars.mean_f[:nbatch]
         targets_unnorm = targets * zvars.std_fdf[:nbatch] + zvars.mean_fdf[:nbatch]
         outputs_unnorm = outputs[:, 0] * zvars.std_fdf[:nbatch, 1] + zvars.mean_fdf[:nbatch, 1]
-
+
+        # some of these [:nbatch] slices may be unnecessary, but unsure
         targets_nof = targets_unnorm - data_unnorm
         outputs_nof = outputs_unnorm[:nbatch] - data_unnorm[:nbatch, 1]

         outputs_nof_to_cat = outputs_nof[:nbatch].unsqueeze(1)
         targets_nof_to_cat = targets_nof[:nbatch, 0].unsqueeze(1)

         # concatenate with actual dfe
-        outputs_nof = torch.cat((outputs_nof_to_cat, targets_nof_to_cat), 1)
+        outputs_nof = torch.cat((targets_nof_to_cat, outputs_nof_to_cat), 1)

         # updated calls to check_properties with correct arguments
         masse_b, massi_b, mom_b, energy_b = check_properties_main(data_unnorm[:, :, :, :-1], \
@@ -1233,17 +1236,18 @@ def train(trainloader,valloader,sp_flag,epoch,end,zvars,cons):
         mom_loss = torch.sum(mom_a) / nbatch
         energy_loss = torch.sum(energy_a) / nbatch

-        if i % 200 == 199:
-            print('outputs', masse_loss.item(), massi_loss.item(), mom_loss.item(), energy_loss.item())
-
         #masse_loss = torch.sum(torch.abs(masse_a - masse_b)).float()/nbatch
         #massi_loss = torch.sum(torch.abs(massi_a - massi_b)).float()/nbatch
         #mass_loss = massi_loss + masse_loss
         #mom_loss = torch.sum(torch.abs(mom_a - mom_b)).float()/nbatch
         #energy_loss = torch.sum(torch.abs(energy_a - energy_b)).float()/nbatch

-        l2_loss = criterion(outputs[:, 0], targets[:, 1])
-
+        l2_loss = criterion(outputs[:, 0], targets[:, 1])
+
+        if i % 100 == 99:
+            print('masse', masse_loss.item(), 'massi', massi_loss.item(), 'mom', mom_loss.item(), 'en', energy_loss.item(), 'l2', l2_loss.item())
+
+
 #        loss = l2_loss*loss_weights[0] \
 #             + mass_loss*loss_weights[1] \
 #             + mom_loss*loss_weights[2] \
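The live loss combination falls outside this hunk; since loss_weights now carries five entries against the four in the commented block, a plausible but unconfirmed reading is one weight per conservation term:

# Hypothetical composite loss; the weight-to-term mapping is assumed, not shown in the diff
loss = l2_loss * loss_weights[0] \
     + masse_loss * loss_weights[1] + massi_loss * loss_weights[2] \
     + mom_loss * loss_weights[3] + energy_loss * loss_weights[4]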
@@ -1273,7 +1277,7 @@ def train(trainloader,valloader,sp_flag,epoch,end,zvars,cons):
         running_loss += loss.item()
         running_l2_loss += l2_loss.item()
         running_cons_loss += cons_loss.item()
-
+
         if i % output_rate == output_rate - 1:
             print(' [%d, %5d] loss: %.6f' %
                   (epoch + 1, end + i + 1, running_loss / output_rate))
@@ -1300,7 +1304,7 @@ def train(trainloader,valloader,sp_flag,epoch,end,zvars,cons):
         if val_loss < np.min(val_loss_vector):  ## check this
             is_best = True

-        if i % (2 * val_rate) == (2 * val_rate - 1):
+        if i % val_rate == val_rate - 1:
             save_checkpoint({
                 'epoch': epoch + 1,
                 'state_dict': net.state_dict(),
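The save_checkpoint helper is defined outside this diff; a minimal sketch of the conventional pattern it presumably follows, with file names assumed for illustration:

import shutil

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)  # always persist the latest training state
    if is_best:
        # keep a separate copy of the best model so far
        shutil.copyfile(filename, 'model_best.pth.tar')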
@@ -1344,7 +1348,7 @@ def validate(valloader,cons,zvars):
         targets_nof_to_cat = targets_nof[:nbatch, 0].unsqueeze(1)

         # concatenate with actual dfe
-        outputs_nof = torch.cat((outputs_nof_to_cat, targets_nof_to_cat), 1)
+        outputs_nof = torch.cat((targets_nof_to_cat, outputs_nof_to_cat), 1)

         masse_b, massi_b, mom_b, energy_b = check_properties_main(data_unnorm[:, :, :, :-1], \
             targets_nof[:, :, :, :-1], temp, vol, cons)
@@ -1503,11 +1507,9 @@ def plot_df(df_xgc,df_ml,epoch):
 criterion = nn.MSELoss()

 optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
-#optimizer = optim.Adam(net.parameters(), lr=lr)
+#optimizer = RAdam(net.parameters())
 scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=lr_decay)
-
-train_loss = []
-val_loss = []
+#scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

 for epoch in range(num_epochs):

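Note the stepping contract differs between the two schedulers: StepLR is stepped unconditionally once per epoch, while the commented-out ReduceLROnPlateau must be fed a monitored metric, which matches the commented scheduler.step(val_loss[-1]) near the end of the epoch loop. A sketch with assumed settings (PyTorch defaults are mode='min', factor=0.1, patience=10):

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
# then, once per epoch:
scheduler.step(val_loss[-1])  # decays the LR only when validation loss plateaus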
@@ -1554,9 +1556,13 @@ def plot_df(df_xgc,df_ml,epoch):

             for loss1 in train_loss_to_app:
                 train_loss.append(loss1)
-            for loss2 in val_loss_to_app:
-                val_loss.append(loss2)
-            cons_array = np.concatenate((cons_array, cons_to_cat), axis=1)
+            for loss2 in l2_loss_to_app:
+                l2_loss.append(loss2)
+            for loss3 in cons_loss_to_app:
+                cons_loss.append(loss3)
+            for loss4 in val_loss_to_app:
+                val_loss.append(loss4)
+#            cons_array = np.concatenate((cons_array, cons_to_cat), axis=1)

         else:
             del f_test, df_test, temp_test, vol_test
@@ -1572,19 +1578,22 @@ def plot_df(df_xgc,df_ml,epoch):
                 cons_loss.append(loss3)
             for loss4 in val_loss_to_app:
                 val_loss.append(loss4)
-            cons_array = np.concatenate((cons_array, cons_to_cat), axis=0)
+#            cons_array = np.concatenate((cons_array, cons_to_cat), axis=0)

         train2 = timeit.default_timer()
         print('Finished training iphi = {}'.format(iphi))
         print(' Training time for iphi = %d: %.3fs' % (iphi, train2 - train1))

-    train_iterations = np.linspace(1, len(train_loss), len(train_loss))
-    val_iterations = np.linspace(2, len(train_loss), len(val_loss))
-
+#    train_iterations = np.linspace(1, len(train_loss), len(train_loss))
+#    val_iterations = np.linspace(2, len(train_loss), len(val_loss))
+
     fid_loss1 = open('train_tmp.txt', 'w')
     fid_loss2 = open('val_tmp.txt', 'w')
     fid_loss3 = open('l2_tmp.txt', 'w')
     fid_loss4 = open('cons_tmp.txt', 'w')
+    lr_command = 'w' if epoch == 0 else 'a'  # overwrite on the first epoch, append after
+    fid_lr = open('lr.txt', lr_command)
+
     for loss in train_loss:
         fid_loss1.write(str(loss) + '\n')
     for loss in val_loss:
@@ -1593,21 +1602,25 @@ def plot_df(df_xgc,df_ml,epoch):
         fid_loss3.write(str(loss) + '\n')
     for loss in cons_loss:
         fid_loss4.write(str(loss) + '\n')
+    fid_lr.write(str(lr_epoch) + '\n')
+
     fid_loss1.close()
     fid_loss2.close()
     fid_loss3.close()
     fid_loss4.close()
+    fid_lr.close()

-    plt.plot(train_iterations, train_loss, '-o', color='blue')
-    plt.plot(val_iterations, val_loss, '-o', color='orange')
-    plt.plot(train_iterations, l2_loss, '-o', color='red')
-    plt.plot(train_iterations, cons_loss, '-o', color='green')
-    plt.legend(['total', 'validation', 'l2', 'cons'])
-    plt.yscale('log')
-    plt.show()
+#    plt.plot(train_iterations, train_loss, '-o', color='blue')
+#    plt.plot(val_iterations, val_loss, '-o', color='orange')
+#    plt.plot(train_iterations, l2_loss, '-o', color='red')
+#    plt.plot(train_iterations, cons_loss, '-o', color='green')
+#    plt.legend(['total', 'validation', 'l2', 'cons'])
+#    plt.yscale('log')
+#    plt.show()

     epoch2 = timeit.default_timer()
     scheduler.step()
+    #scheduler.step(val_loss[-1])
     print('Epoch time: {}s\n'.format(epoch2 - epoch1))

 print('Starting testing')
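Since the inline plotting above is now commented out in favor of the *_tmp.txt dumps, the same curves can be reproduced offline from those files. A minimal sketch, assuming one loss value per line as written above:

import numpy as np
import matplotlib.pyplot as plt

train = np.loadtxt('train_tmp.txt')
val = np.loadtxt('val_tmp.txt')
l2 = np.loadtxt('l2_tmp.txt')
cons = np.loadtxt('cons_tmp.txt')
it = np.arange(1, len(train) + 1)
plt.plot(it, train, '-o', color='blue')
plt.plot(np.linspace(2, len(train), len(val)), val, '-o', color='orange')
plt.plot(it, l2, '-o', color='red')
plt.plot(it, cons, '-o', color='green')
plt.legend(['total', 'validation', 'l2', 'cons'])
plt.yscale('log')
plt.show()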