forked from YuZheng9/C2PNet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdehaze.py
More file actions
71 lines (61 loc) · 2.17 KB
/
dehaze.py
File metadata and controls
71 lines (61 loc) · 2.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import os
import numpy as np
import torch
import torchvision.transforms as tfs
import torchvision.utils as vutils
from PIL import Image
from tqdm import tqdm
from metrics import psnr, ssim
from models.C2PNet import C2PNet
# Command-line options: which SOTS split to evaluate (-d/--dataset_name),
# whether to write the dehazed images (--save), and where to write them
# (--save_dir).
parser = argparse.ArgumentParser()
parser.add_argument(
    '-d', '--dataset_name',
    choices=['indoor', 'outdoor'],
    default='indoor',
    help='name of dataset',
)
parser.add_argument(
    '--save_dir',
    type=str,
    default='dehaze_images',
    help='dehaze images save path',
)
parser.add_argument(
    '--save',
    action='store_true',
    help='save dehaze images',
)
opt = parser.parse_args()
dataset = opt.dataset_name
# When saving is requested, predictions are written to <save_dir>/<dataset>/.
# os.makedirs creates any missing parent directories (os.mkdir fails for a
# nested --save_dir such as "results/run1"), and exist_ok=True replaces the
# racy "exists? then mkdir" check. output_dir is only defined in this case;
# it is only read when opt.save is set.
if opt.save:
    output_dir = os.path.join(opt.save_dir, dataset)
    print("pred_dir:", output_dir)
    os.makedirs(output_dir, exist_ok=True)
# Per-split locations: (hazy input dir, ground-truth clear dir, checkpoint).
# The indoor split uses the ITS checkpoint; the outdoor split uses OTS.
sots_paths = {
    'indoor': (
        'data/SOTS/indoor/hazy/',
        'data/SOTS/indoor/clear/',
        'trained_models/ITS.pkl',
    ),
    'outdoor': (
        'data/SOTS/outdoor/hazy/',
        'data/SOTS/outdoor/clear/',
        'trained_models/OTS.pkl',
    ),
}
if dataset in sots_paths:
    haze_dir, clear_dir, model_dir = sots_paths[dataset]
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
net = C2PNet(gps=3, blocks=19)
# map_location remaps tensors saved on a CUDA device onto `device`, so a
# GPU-trained checkpoint still loads on a CPU-only machine (without it,
# torch.load raises when CUDA is unavailable).
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
ckp = torch.load(model_dir, map_location=device)
net = net.to(device)
# The checkpoint is a dict whose 'model' entry holds the network state_dict.
net.load_state_dict(ckp['model'])
# Switch to evaluation mode (affects layers such as dropout/batch norm, if any).
net.eval()
# Dehaze every image in haze_dir, score it against its clear counterpart with
# PSNR/SSIM, optionally save the prediction, then report the averages.
to_tensor = tfs.ToTensor()  # stateless transform: build once, not per image
psnr_list = []
ssim_list = []
# os.listdir order is arbitrary; sort for a deterministic processing order.
for im in tqdm(sorted(os.listdir(haze_dir))):
    # Context managers close the source files promptly.
    with Image.open(os.path.join(haze_dir, im)) as img:
        haze = img.convert('RGB')
    # The clear image name is derived from the hazy one: the part before the
    # first '_' plus '.png' (assumes SOTS naming like "<id>_<...>.<ext>").
    if dataset == 'indoor' or dataset == 'outdoor':
        clear_im = im.split('_')[0] + '.png'
    else:
        clear_im = im
    with Image.open(os.path.join(clear_dir, clear_im)) as img:
        clear = img.convert('RGB')
    # Add a leading batch dimension to the (C, H, W) tensors.
    haze1 = to_tensor(haze)[None, ::].to(device)
    clear_no = to_tensor(clear)[None, ::]
    with torch.no_grad():
        pred = net(haze1)
    # Copy the prediction to the host once; reused for the image and metrics.
    pred_cpu = pred.cpu()
    ts = torch.squeeze(pred_cpu.clamp(0, 1))
    psnr_list.append(psnr(pred_cpu, clear_no))
    ssim_list.append(ssim(pred_cpu, clear_no))
    if opt.save:
        vutils.save_image(ts, os.path.join(output_dir, im))
if psnr_list:
    print(f'Average PSNR is {np.mean(psnr_list)}')
    print(f'Average SSIM is {np.mean(ssim_list)}')
else:
    # np.mean([]) would report nan with a RuntimeWarning; be explicit instead.
    print(f'No images found in {haze_dir}')