forked from Separius/SimCLRv2-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path verify.py
More file actions
71 lines (59 loc) · 2.36 KB
/
verify.py
File metadata and controls
71 lines (59 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import argparse
from collections import Counter
import torch
from PIL import Image
from tqdm import tqdm
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from resnet import get_resnet, name_to_params
class ImagenetValidationDataset(Dataset):
    """ILSVRC2012 validation images read from a single flat directory.

    ``val_path`` must contain ``ILSVRC2012_validation_ground_truth.txt`` (one
    integer label per line, shifted here by -1 to be 0-based) and images named
    ``ILSVRC2012_val_XXXXXXXX.JPEG``; dataset index ``i`` maps to file number
    ``i + 1``.

    Args:
        val_path: directory holding the label file and the images.
        transform: callable applied to each RGB PIL image. Defaults to
            Resize(256) -> CenterCrop(224) -> ToTensor().
    """

    def __init__(self, val_path, transform=None):
        super().__init__()
        self.val_path = val_path
        if transform is None:
            transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])
        self.transform = transform
        self.labels = self._read_labels(val_path)

    @staticmethod
    def _read_labels(val_path):
        """Return the 0-based labels from the ground-truth file, ignoring blank lines."""
        label_file = os.path.join(val_path, 'ILSVRC2012_validation_ground_truth.txt')
        with open(label_file, encoding='utf-8') as f:
            # Skip blank/whitespace-only lines (e.g. a trailing empty line) instead of crashing in int().
            return [int(line) - 1 for line in f if line.strip()]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        image_file = os.path.join(self.val_path, f'ILSVRC2012_val_{item + 1:08d}.JPEG')
        # convert() fully decodes into a new image, so the source file can be closed right away.
        with Image.open(image_file) as raw:
            img = raw.convert('RGB')
        return self.transform(img), self.labels[item]
def accuracy(output, target, topk=(1,)):
    """Count the samples whose target is among the top-k scored classes.

    Args:
        output: 2-D score tensor; rows are samples, and top-k is taken over dim 1.
        target: integer class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list of floats, one per entry of ``topk``: the number (not the
        fraction) of samples that are correct at that k.
    """
    maxk = max(topk)
    # Indices of the maxk best classes per row, sorted by score, transposed to (maxk, batch).
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # Compare on pred's device so a GPU target cannot clash with a CPU pred (or vice versa).
    correct = pred.eq(target.view(1, -1).to(pred.device).expand_as(pred))
    # reshape, not view: `correct` inherits the transposed (non-contiguous) layout,
    # and view(-1) on correct[:k] raises for k > 1.
    return [correct[:k].reshape(-1).float().sum().item() for k in topk]
def _majority_vote_accuracy(predictions, targets):
    """Return the majority-vote-mapped accuracy of ``predictions``, in percent.

    For every ground-truth class, the count of its most frequent predicted label
    is treated as "correct"; the sum over classes is divided by the number of
    samples. This is NOT plain top-1 accuracy: it never checks that the
    predicted label equals the true one, and several true classes may map to
    the same predicted label, so it can overstate accuracy.

    Args:
        predictions: sequence of predicted labels.
        targets: sequence of ground-truth labels, aligned with ``predictions``.

    Raises:
        ValueError: if the sequences differ in length or are empty.
    """
    if len(predictions) != len(targets):
        raise ValueError(f'{len(predictions)} predictions for {len(targets)} targets')
    if not targets:
        raise ValueError('no samples to score')
    pair_counts = Counter(zip(targets, predictions))
    best_per_class = {}
    for (true_label, _), count in pair_counts.items():
        best_per_class[true_label] = max(best_per_class.get(true_label, 0), count)
    return sum(best_per_class.values()) / len(targets) * 100


@torch.no_grad()
def run(pth_path, val_path='./val/', batch_size=64, num_workers=8, device=None):
    """Evaluate the checkpoint at ``pth_path`` on the validation set and print ``ACC: <score>``.

    The architecture comes from ``name_to_params(pth_path)`` and the weights
    from the checkpoint's ``'resnet'`` entry. The score is computed by
    ``_majority_vote_accuracy`` over top-1 predictions.

    Args:
        pth_path: path of the checkpoint file.
        val_path: directory passed to ``ImagenetValidationDataset``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
        device: torch device string; defaults to ``'cuda'`` when available, else ``'cpu'``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dataset = ImagenetValidationDataset(val_path)
    # Pinned host memory only helps host-to-GPU copies.
    pin = torch.device(device).type == 'cuda'
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, pin_memory=pin, num_workers=num_workers)
    model, _ = get_resnet(*name_to_params(pth_path))
    # NOTE(security): torch.load unpickles the file -- only load checkpoints you trust.
    # map_location='cpu' lets a GPU-saved checkpoint load on any host; the model is moved below.
    checkpoint = torch.load(pth_path, map_location='cpu')
    model.load_state_dict(checkpoint['resnet'])
    model = model.to(device).eval()
    preds = []
    target = []
    for images, labels in tqdm(data_loader):
        _, top1 = model(images.to(device), apply_fc=True).topk(1, dim=1)
        preds.append(top1.squeeze(1).cpu())
        target.append(labels)
    # NOTE(review): the majority-vote score is presumably used because the ground-truth
    # label indices are ordered differently from the model's output classes -- confirm
    # before reading this number as standard top-1 accuracy.
    score = _majority_vote_accuracy(torch.cat(preds).tolist(), torch.cat(target).tolist())
    print(f'ACC: {score}')
if __name__ == '__main__':
    # Command-line entry point: a single positional checkpoint path, handed to run().
    cli = argparse.ArgumentParser(description='SimCLR verifier')
    cli.add_argument('pth_path', type=str, help='path of the input checkpoint file')
    cli_args = cli.parse_args()
    run(cli_args.pth_path)