
Commit 12146c2: Add files via upload
Parent: 217c96c

File tree: 3 files changed (+2001, -43 lines)


full_both_sp.py: 56 additions, 43 deletions
@@ -20,23 +20,25 @@
 from scipy import stats
 from torch.nn import functional as F
 from torch.autograd import Variable
+from radam import RAdam, PlainRAdam, AdamW
 
-batch_size = 32
-lr = 1e-5
+batch_size = 64
+lr = 1e-4
 momentum = 0.99
-num_epochs = 100
+num_epochs = 200
 percentage_train = 0.8
 percentage_val = 0.1
-lr_decay = 0.5
-step_size = 5
+lr_decay = 0.25
+step_size = 20
 # loss_weights = [1,1e0,1e21,1e15]
-loss_weights = [1,0.1,0.1,0.1,0.1]
-nphi = 1
+loss_weights = [1,0.05,0.05,0.05,0.05]
+nphi = 4
 plot_rate = 250
-output_rate = 500
-val_rate = 2000
+output_rate = 250
+val_rate = 1000
 datapath = '/scratch/gpfs/marcoam/ml_collisions/data/xgc1/ti272_JET_heat_load/'
 run_num = '00094/'
+lim = 80000
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
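Note: these constants feed the optim.SGD / StepLR setup that appears later in this diff, where the learning rate is multiplied by lr_decay every step_size epochs. A minimal sketch of the resulting schedule, using standard torch.optim behavior rather than code from this commit:

import torch
from torch import optim

lr, lr_decay, step_size = 1e-4, 0.25, 20        # values set in this commit
params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for net.parameters()
optimizer = optim.SGD(params, lr=lr, momentum=0.99)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=lr_decay)
for epoch in range(200):                        # num_epochs; training step omitted
    scheduler.step()
# lr is multiplied by 0.25 every 20 epochs, so after 200 epochs:
# 1e-4 * 0.25**10 ≈ 9.5e-11
print(optimizer.param_groups[0]['lr'])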
@@ -917,11 +919,11 @@ def load_data_hdf(iphi):
     i_df = hf_df['i_df'][iphi]
 
     ind1,ind2,ind3 = i_f.shape
-
-    f = np.zeros([ind2,2,ind1,ind1])
-    df = np.zeros([ind2,2,ind1,ind1])
+    # change lim back to ind2 if want full set
+    f = np.zeros([lim,2,ind1,ind1])
+    df = np.zeros([lim,2,ind1,ind1])
 
-    for n in range(ind2):
+    for n in range(lim):
         f[n,0,:,:-1] = e_f[:,n,:]
         f[n,1,:,:-1] = i_f[:,n,:]
         df[n,0,:,:-1] = e_df[:,n,:]
@@ -943,7 +945,7 @@ def load_data_hdf(iphi):
     hf_stats = h5py.File(datapath+run_num+'hdf_stats.h5','r')
     zvars = stats_variables(hf_stats)
 
-    for n in range(ind2):
+    for n in range(lim):
         f[n] = (f[n]-zvars.mean_f)/zvars.std_f
         # df[n] = (df[n]-zvars.mean_df)/zvars.std_df
         df[n] = (df[n]-zvars.mean_fdf)/zvars.std_fdf
@@ -970,7 +972,7 @@ def load_data_hdf(iphi):
     zvars.std_df = torch.from_numpy(zvars.std_df).to(device).float()
     zvars.std_fdf = torch.from_numpy(zvars.std_fdf).to(device).float()
 
-    return f,df,ind2,zvars,cons
+    return f,df,lim,zvars,cons
 
 class stats_variables():
 
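The new lim = 80000 cap truncates each plane to its first 80000 samples instead of the full ind2, so the preallocated arrays, both loops, and the returned count all switch from ind2 to lim. A toy sketch of the pattern (sizes and the layout of e_f are assumptions; the real arrays come from the HDF5 file):

import numpy as np

ind1, ind2, lim = 33, 1000, 800              # toy sizes; the commit uses lim = 80000
e_f = np.random.rand(ind1, ind2, ind1 - 1)   # assumed (v_par, node, v_perp) layout
assert lim <= ind2                           # otherwise the loop reads past the data
f = np.zeros([lim, 2, ind1, ind1])           # allocate only the truncated set
for n in range(lim):                         # keep the first lim samples only
    f[n, 0, :, :-1] = e_f[:, n, :]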
@@ -1204,15 +1206,16 @@ def train(trainloader,valloader,sp_flag,epoch,end,zvars,cons):
         data_unnorm = data*zvars.std_f[:nbatch] + zvars.mean_f[:nbatch]
         targets_unnorm = targets*zvars.std_fdf[:nbatch] + zvars.mean_fdf[:nbatch]
         outputs_unnorm = outputs[:,0]*zvars.std_fdf[:nbatch,1] + zvars.mean_fdf[:nbatch,1]
-
+
+        # don't think I need some of these nbatch but unsure
         targets_nof = targets_unnorm - data_unnorm
         outputs_nof = outputs_unnorm[:nbatch] - data_unnorm[:nbatch,1]
 
         outputs_nof_to_cat = outputs_nof[:nbatch].unsqueeze(1)
         targets_nof_to_cat = targets_nof[:nbatch,0].unsqueeze(1)
 
         # concatenate with actual dfe
-        outputs_nof = torch.cat((outputs_nof_to_cat,targets_nof_to_cat),1)
+        outputs_nof = torch.cat((targets_nof_to_cat,outputs_nof_to_cat),1)
 
         # updated calls to check_properties with correct arguments
         masse_b,massi_b,mom_b,energy_b = check_properties_main(data_unnorm[:,:,:,:-1],\
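The torch.cat swap reorders the channels of outputs_nof: channel 0 now carries the target (the "actual dfe" of the comment) and channel 1 the network prediction, so anything downstream that indexes channel 0 or 1 sees swapped roles. A toy illustration with 2-D stand-ins (the real tensors are 4-D after unsqueeze(1)):

import torch

pred = torch.zeros(4, 1)                 # stands in for outputs_nof_to_cat
target = torch.ones(4, 1)                # stands in for targets_nof_to_cat
before = torch.cat((pred, target), 1)    # old order: channel 0 = prediction
after = torch.cat((target, pred), 1)     # new order: channel 0 = target
print(before[0], after[0])               # tensor([0., 1.]) tensor([1., 0.])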
@@ -1233,17 +1236,18 @@ def train(trainloader,valloader,sp_flag,epoch,end,zvars,cons):
         mom_loss = torch.sum(mom_a)/nbatch
         energy_loss = torch.sum(energy_a)/nbatch
 
-        if i % 200 == 199:
-            print('outputs',masse_loss.item(),massi_loss.item(),mom_loss.item(),energy_loss.item())
-
         #masse_loss = torch.sum(torch.abs(masse_a - masse_b)).float()/nbatch
         #massi_loss = torch.sum(torch.abs(massi_a - massi_b)).float()/nbatch
         #mass_loss = massi_loss + masse_loss
         #mom_loss = torch.sum(torch.abs(mom_a - mom_b)).float()/nbatch
         #energy_loss = torch.sum(torch.abs(energy_a - energy_b)).float()/nbatch
 
-        l2_loss = criterion(outputs[:,0],targets[:,1])
-
+        l2_loss = criterion(outputs[:,0],targets[:,1])
+
+        if i % 100 == 99:
+            print('masse',masse_loss.item(),'massi',massi_loss.item(),'mom',mom_loss.item(),'en',energy_loss.item(),'l2',l2_loss.item())
+
+
         # loss = l2_loss*loss_weights[0] \
         #      + mass_loss*loss_weights[1] \
         #      + mom_loss*loss_weights[2] \
@@ -1273,7 +1277,7 @@ def train(trainloader,valloader,sp_flag,epoch,end,zvars,cons):
         running_loss += loss.item()
         running_l2_loss += l2_loss.item()
         running_cons_loss += cons_loss.item()
-
+
         if i % output_rate == output_rate-1:
             print(' [%d, %5d] loss: %.6f' %
                   (epoch + 1, end + i + 1, running_loss / output_rate))
@@ -1300,7 +1304,7 @@ def train(trainloader,valloader,sp_flag,epoch,end,zvars,cons):
         if val_loss < np.min(val_loss_vector): ## check this
             is_best = True
 
-        if i % (2*val_rate) == (2*val_rate-1):
+        if i % val_rate == val_rate-1:
             save_checkpoint({
                 'epoch': epoch+1,
                 'state_dict': net.state_dict(),
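Since val_rate also drops from 2000 to 1000 in this commit, checkpoints now go out every 1000 batches instead of every 4000. The body of save_checkpoint is not part of this diff; a plausible minimal implementation, following the common PyTorch checkpoint idiom and offered purely as an assumption:

import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # Hypothetical helper; the real definition lives outside this diff.
    torch.save(state, filename)          # persist the state-dict bundle
    if is_best:                          # keep a copy of the best model so far
        shutil.copyfile(filename, 'model_best.pth.tar')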
@@ -1344,7 +1348,7 @@ def validate(valloader,cons,zvars):
         targets_nof_to_cat = targets_nof[:nbatch,0].unsqueeze(1)
 
         # concatenate with actual dfe
-        outputs_nof = torch.cat((outputs_nof_to_cat,targets_nof_to_cat),1)
+        outputs_nof = torch.cat((targets_nof_to_cat,outputs_nof_to_cat),1)
 
         masse_b,massi_b,mom_b,energy_b = check_properties_main(data_unnorm[:,:,:,:-1],\
             targets_nof[:,:,:,:-1],temp,vol,cons)
@@ -1503,11 +1507,9 @@ def plot_df(df_xgc,df_ml,epoch):
 criterion = nn.MSELoss()
 
 optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
-#optimizer = optim.Adam(net.parameters(), lr=lr)
+#optimizer = RAdam(net.parameters())
 scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=step_size,gamma=lr_decay)
-
-train_loss = []
-val_loss = []
+#scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
 
 for epoch in range(num_epochs):
 
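The commented-out ReduceLROnPlateau line pairs with the commented scheduler.step(val_loss[-1]) call near the end of the epoch loop: unlike StepLR, which decays on a fixed epoch count and takes no argument, ReduceLROnPlateau must be stepped with a monitored metric. Both call patterns, per the standard torch.optim API:

import torch
from torch import optim

net = torch.nn.Linear(4, 4)               # stand-in model
optimizer = optim.SGD(net.parameters(), lr=1e-4, momentum=0.99)

step_sched = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.25)
step_sched.step()                          # fixed schedule: no argument

plateau_sched = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
plateau_sched.step(0.123)                  # metric-driven: pass the validation loss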
@@ -1554,9 +1556,13 @@ def plot_df(df_xgc,df_ml,epoch):
 
         for loss1 in train_loss_to_app:
             train_loss.append(loss1)
-        for loss2 in val_loss_to_app:
-            val_loss.append(loss2)
-        cons_array = np.concatenate((cons_array, cons_to_cat), axis=1)
+        for loss2 in l2_loss_to_app:
+            l2_loss.append(loss2)
+        for loss3 in cons_loss_to_app:
+            cons_loss.append(loss3)
+        for loss4 in val_loss_to_app:
+            val_loss.append(loss4)
+        # cons_array = np.concatenate((cons_array, cons_to_cat), axis=1)
 
     else:
         del f_test,df_test,temp_test,vol_test
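As an aside, each of these per-element loops is equivalent to a single list.extend call:

train_loss.extend(train_loss_to_app)   # same effect as the loop above
l2_loss.extend(l2_loss_to_app)
cons_loss.extend(cons_loss_to_app)
val_loss.extend(val_loss_to_app)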
@@ -1572,19 +1578,22 @@ def plot_df(df_xgc,df_ml,epoch):
             cons_loss.append(loss3)
         for loss4 in val_loss_to_app:
             val_loss.append(loss4)
-        cons_array = np.concatenate((cons_array, cons_to_cat), axis=0)
+        #cons_array = np.concatenate((cons_array, cons_to_cat), axis=0)
 
     train2 = timeit.default_timer()
     print('Finished tranining iphi = {}'.format(iphi))
     print(' Training time for iphi = %d: %.3fs' % (iphi,train2-train1))
 
-    train_iterations = np.linspace(1,len(train_loss),len(train_loss))
-    val_iterations = np.linspace(2,len(train_loss),len(val_loss))
-
+    #train_iterations = np.linspace(1,len(train_loss),len(train_loss))
+    #val_iterations = np.linspace(2,len(train_loss),len(val_loss))
+
     fid_loss1 = open('train_tmp.txt','w')
     fid_loss2 = open('val_tmp.txt','w')
     fid_loss3 = open('l2_tmp.txt','w')
     fid_loss4 = open('cons_tmp.txt','w')
+    lr_command = 'w' if epoch == 0 else 'a'
+    fid_lr = open('lr.txt',lr_command)
+
     for loss in train_loss:
         fid_loss1.write(str(loss)+'\n')
     for loss in val_loss:
@@ -1593,21 +1602,25 @@ def plot_df(df_xgc,df_ml,epoch):
         fid_loss3.write(str(loss)+'\n')
     for loss in cons_loss:
         fid_loss4.write(str(loss)+'\n')
+    fid_lr.write(str(lr_epoch)+'\n')
+
     fid_loss1.close()
     fid_loss2.close()
     fid_loss3.close()
     fid_loss4.close()
+    fid_lr.close()
 
-    plt.plot(train_iterations,train_loss,'-o',color='blue')
-    plt.plot(val_iterations,val_loss,'-o',color='orange')
-    plt.plot(train_iterations,l2_loss,'-o',color='red')
-    plt.plot(train_iterations,cons_loss,'-o',color='green')
-    plt.legend(['total','validation','l2','cons'])
-    plt.yscale('log')
-    plt.show()
+    #plt.plot(train_iterations,train_loss,'-o',color='blue')
+    #plt.plot(val_iterations,val_loss,'-o',color='orange')
+    #plt.plot(train_iterations,l2_loss,'-o',color='red')
+    #plt.plot(train_iterations,cons_loss,'-o',color='green')
+    #plt.legend(['total','validation','l2','cons'])
+    #plt.yscale('log')
+    #plt.show()
 
     epoch2 = timeit.default_timer()
     scheduler.step()
+    #scheduler.step(val_loss[-1])
     print('Epoch time: {}s\n'.format(epoch2-epoch1))
 
 print('Starting testing')
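lr_epoch, written to lr.txt once per epoch, is not defined anywhere in this diff, so it presumably comes from elsewhere in the file. If it is meant to be the current learning rate, the usual ways to read it are:

lr_epoch = optimizer.param_groups[0]['lr']   # current LR straight from the optimizer
# or, with the StepLR scheduler attached:
# lr_epoch = scheduler.get_last_lr()[0]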
