Skip to content

Commit 313d953

Browse files
author
Marco A. Miller
committed
One more change to datatypes
1 parent 0858d3f commit 313d953

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@
4747

4848
"""# choose network"""
4949

50-
net = Unet().to(device)
51-
#net = ReSeg().to(device)
50+
#net = Unet().to(device)
51+
net = ReSeg().to(device)
5252
#net = StackedRecurrentHourglass().to(device)
5353

5454
#print(sum(p.numel() for p in net.parameters() if p.requires_grad))
@@ -388,7 +388,7 @@ def train(trainloader,valloader,sp_flag,epoch,end,zvars,cons):
388388
else:
389389
optimizer_e.zero_grad()
390390

391-
outputs = net(data)
391+
outputs = net(data.float()).double()
392392
outputs = outputs.to(device)
393393

394394
nbatch = len(data)
@@ -568,7 +568,7 @@ def test(f_test,df_test,temp_test,vol_test):
568568

569569
data, targets, temp, vol = data.to(device), targets.to(device), temp.to(device), vol.to(device)
570570

571-
outputs = net(data)
571+
outputs = net(data.float()).double()
572572
outputs = outputs.to(device)
573573

574574
nbatch = len(data)

0 commit comments

Comments
 (0)