forked from lgraesser/CNN-Tutorial
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_and_eval.py
More file actions
56 lines (52 loc) · 2.21 KB
/
train_and_eval.py
File metadata and controls
56 lines (52 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from torch.autograd import Variable
def train(epoch, net, dataloader, criterion, optimizer, cuda, batch_size):
    """Run one training epoch over ``dataloader`` and report loss/accuracy.

    Batches whose first dimension differs from ``batch_size`` (e.g. a final
    partial batch) are skipped, so the reported statistics cover only the
    samples that were actually trained on.

    Args:
        epoch: Epoch number, used only in progress messages.
        net: Model to train; switched to training mode via ``net.train()``.
        dataloader: Iterable of ``(data, target)`` batches exposing
            ``.dataset`` and ``len()`` (used for progress reporting).
        criterion: Loss called as ``criterion(output, target)``. Assumed to
            return the mean loss over the batch (torch's default
            ``reduction='mean'``) -- confirm for custom losses.
        optimizer: Optimizer that updates ``net``'s parameters.
        cuda: If truthy and CUDA is available, batches are moved to the GPU.
        batch_size: Expected batch size; batches of any other size are skipped.

    Returns:
        ``(average_loss, accuracy)`` over the processed samples, with accuracy
        as a fraction in [0, 1]. Both are 0.0 if no batch was processed.
    """
    net.train()
    use_cuda = cuda and torch.cuda.is_available()
    correct = 0
    total_loss = 0.0
    seen = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        # Only full-size batches are used (the caller passes batch_size for this).
        if data.size(0) != batch_size:
            continue
        if use_cuda:
            data = data.cuda()
            target = target.cuda()
        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Keep track of accuracy and loss. ``loss.item()`` replaces the
        # pre-0.4 ``loss.data[0]``, which raises IndexError on 0-dim tensors.
        batch_loss = loss.item()
        n = data.size(0)
        pred = output.detach().argmax(dim=1)
        correct += pred.eq(target.view_as(pred)).sum().item()
        # Weight the per-batch mean by batch size to get a per-sample mean.
        total_loss += batch_loss * n
        seen += n

        # Report progress
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(dataloader.dataset),
                100. * batch_idx / len(dataloader), batch_loss))

    # Divide by the samples actually processed (skipped batches excluded);
    # guard against every batch having been skipped.
    denom = max(seen, 1)
    avg_loss = total_loss / denom
    accuracy = correct / denom
    print('During training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        avg_loss, correct, seen, 100. * accuracy))
    return avg_loss, accuracy
def evaluate(epoch, net, dataloader, criterion, cuda, batch_size):
    """Evaluate ``net`` on ``dataloader`` and report loss/accuracy.

    Batches whose first dimension differs from ``batch_size`` (e.g. a final
    partial batch) are skipped, so the reported statistics cover only the
    samples that were actually evaluated.

    Args:
        epoch: Not used by this function; kept for signature symmetry with
            ``train``.
        net: Model to evaluate; switched to eval mode via ``net.eval()``.
        dataloader: Iterable of ``(data, target)`` batches.
        criterion: Loss called as ``criterion(output, target)``. Assumed to
            return the mean loss over the batch (torch's default
            ``reduction='mean'``) -- confirm for custom losses.
        cuda: If truthy and CUDA is available, batches are moved to the GPU.
        batch_size: Expected batch size; batches of any other size are skipped.

    Returns:
        ``(average_loss, accuracy)`` over the processed samples, with accuracy
        as a fraction in [0, 1]. Both are 0.0 if no batch was processed.
    """
    net.eval()
    use_cuda = cuda and torch.cuda.is_available()
    correct = 0
    total_loss = 0.0
    seen = 0
    # Nothing is back-propagated here, so don't build autograd graphs.
    with torch.no_grad():
        for data, target in dataloader:
            # Only full-size batches are used (the caller passes batch_size for this).
            if data.size(0) != batch_size:
                continue
            if use_cuda:
                data = data.cuda()
                target = target.cuda()
            output = net(data)
            loss = criterion(output, target)

            # Keep track of accuracy and loss. ``loss.item()`` replaces the
            # pre-0.4 ``loss.data[0]``, which raises IndexError on 0-dim tensors.
            n = data.size(0)
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
            # Weight the per-batch mean by batch size to get a per-sample mean.
            total_loss += loss.item() * n
            seen += n

    # Divide by the samples actually processed (skipped batches excluded);
    # guard against every batch having been skipped.
    denom = max(seen, 1)
    avg_loss = total_loss / denom
    accuracy = correct / denom
    print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        avg_loss, correct, seen, 100. * accuracy))
    return avg_loss, accuracy