-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_7cls.py
More file actions
60 lines (53 loc) · 2.08 KB
/
test_7cls.py
File metadata and controls
60 lines (53 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse
import cv2
import numpy as np
import torch
from torch.autograd import Function
from torchvision import models,transforms
import torch.nn.functional as F
import torch.nn as nn
import pretrainedmodels
from glob import glob
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from DatasetGenerator import DatasetGenerator
# Class names, indexed by the model's output unit. The order presumably
# matches the label encoding used during training — TODO confirm.
label_7cls = ["normal", "oa_I", "oa_II", "oa_III", "onfh_II", "onfh_III", "onfh_IV"]

# Load the trained checkpoint (best 7-class model).
modelCheckpoint = torch.load("./models/hip_7cls_model.pth.tar")

# Build an Xception backbone and replace its classifier with one output per
# label. pretrained=None: the strict load_state_dict below overwrites every
# parameter and buffer, so fetching ImageNet weights would only cost a
# network download.
model = pretrainedmodels.__dict__['xception'](num_classes=1000, pretrained=None)
num_fc = model.last_linear.in_features
model.last_linear = nn.Linear(num_fc, len(label_7cls))

# Wrap in DataParallel *before* loading. The strict load requires the
# checkpoint's state_dict keys to carry DataParallel's 'module.' prefix.
model = torch.nn.DataParallel(model).cuda()
model.load_state_dict(modelCheckpoint['state_dict'])
# Prepare the test split with ten-crop test-time augmentation.
test_data_path = './sample_images/hip_7cls/test/'
test_file = './sample_images/hip_7cls/test.txt'

# One intensity statistic, replicated for each of the three channels.
mean = [0.605] * 3
std = [0.156] * 3
normalize = transforms.Normalize(mean, std)

test_transform = transforms.Compose([
    transforms.Resize(256),
    # Five 224x224 crops plus their horizontally flipped copies -> 10 images.
    transforms.TenCrop(224),
    # Turn every crop into a tensor and stack them -> (10, C, 224, 224).
    transforms.Lambda(
        lambda pil_crops: torch.stack([transforms.ToTensor()(p) for p in pil_crops])
    ),
    # Normalize each crop of the stacked tensor independently.
    transforms.Lambda(
        lambda tensor_crops: torch.stack([normalize(t) for t in tensor_crops])
    ),
])

datasetTest = DatasetGenerator(
    pathImageDirectory=test_data_path,
    pathDatasetFile=test_file,
    transform=test_transform,
)
# One sample per batch, images loaded by eight worker processes.
dataloaderTest = DataLoader(
    dataset=datasetTest,
    batch_size=1,
    num_workers=8,
    pin_memory=True,
)
# Evaluate the test split: for each sample, average the logits over its ten
# crops and print the predicted label next to the ground-truth label.
model.eval()
cudnn.benchmark = True
with torch.no_grad():
    for images, target in dataloaderTest:
        # images is (batch, n_crops, C, H, W): TenCrop stacks crops per sample.
        bs, n_crops, c, h, w = images.size()
        # Fold the crops into the batch dimension for one forward pass, then
        # average the per-crop logits back into one row per sample.
        logits = model(images.view(-1, c, h, w).cuda())
        crop_mean = logits.view(bs, n_crops, -1).mean(1)
        preds = crop_mean.argmax(1).tolist()
        # Assumes target holds one score/one-hot row per sample, so argmax
        # yields the class index — TODO confirm against DatasetGenerator.
        # Kept on the CPU: it is only used to look up a label.
        gts = target.argmax(1).tolist()
        # Works for any batch size; identical output to the old .item() path
        # when batch_size == 1.
        for pred, gt in zip(preds, gts):
            print("pred:", label_7cls[pred], 'gt:', label_7cls[gt])