forked from yhenon/pytorch-retinanet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcsv_validation.py
More file actions
59 lines (46 loc) · 2.14 KB
/
csv_validation.py
File metadata and controls
59 lines (46 loc) · 2.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# In csv_validation.py
import argparse
import torch
from torchvision import transforms
from retinanet import model
from retinanet.dataloader import CSVDataset, Preprocess # Use the robust Preprocess class
from retinanet import csv_eval
def _parse_args(args=None):
    """Parse the command-line options of the evaluation script.

    Args:
        args: Optional list of argument strings; ``None`` makes argparse read
            ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: on ``--help`` or when options are missing/invalid.
    """
    parser = argparse.ArgumentParser(description='Simple evaluation script for a RetinaNet network.')
    # These paths have no usable default; marking them required makes argparse
    # report the omission instead of passing None into dataset/model loading.
    parser.add_argument('--csv_annotations', help='Path to CSV annotations file for validation', required=True)
    parser.add_argument('--model_path', help='Path to saved model state_dict (.pt file)', type=str, required=True)
    parser.add_argument('--class_list', help='Path to classlist csv', type=str, required=True)
    parser.add_argument('--iou_threshold', help='IOU threshold used for mAP', type=float, default=0.5)
    parser.add_argument('--score_threshold', help='Score threshold for Precision/Recall calculation', type=float, default=0.5)
    return parser.parse_args(args)


def _print_results(results):
    """Print every entry of the *results* mapping, numeric values to 4 decimals."""
    print("\n--- Final Metrics ---")
    for key, value in results.items():
        try:
            formatted = f"{value:.4f}"
        except (TypeError, ValueError):
            # Entries that do not accept a float format spec (e.g. tuples or
            # nested dicts) are printed in their plain string form instead.
            formatted = str(value)
        print(f"{str(key).upper()}: {formatted}")
    print("---------------------")


def main(args=None):
    """Evaluate a saved RetinaNet built by ``model.efficientnet_b0_retinanet``.

    Loads the validation CSV dataset, rebuilds the network with the dataset's
    class count, loads the saved state_dict, and runs ``csv_eval.evaluate`` with
    the requested IoU and score thresholds, printing the returned metrics.

    Args:
        args: Optional list of command-line strings; ``None`` reads ``sys.argv``.

    Returns:
        The mapping returned by ``csv_eval.evaluate`` (also printed).

    Raises:
        SystemExit: on ``--help`` or when a required option is missing/invalid.
    """
    opts = _parse_args(args)

    dataset_val = CSVDataset(train_file=opts.csv_annotations, class_list=opts.class_list, transform=Preprocess())

    # Create model and load weights
    print("Creating model structure...")
    num_classes = dataset_val.num_classes()
    retinanet = model.efficientnet_b0_retinanet(num_classes=num_classes)

    print(f"Loading weights from {opts.model_path}...")
    # Load onto the CPU first so a checkpoint saved from a GPU run also loads
    # on a CPU-only machine; the model is moved to the GPU below if available.
    state_dict = torch.load(opts.model_path, map_location='cpu')
    retinanet.load_state_dict(state_dict)

    # Setup device and eval mode
    if torch.cuda.is_available():
        retinanet = retinanet.cuda()
    retinanet.eval()
    # If retinanet exposes a `.module` attribute (a wrapper), freeze the inner
    # model's batch-norm layers; otherwise call freeze_bn on retinanet itself.
    if hasattr(retinanet, 'module'):
        retinanet.module.freeze_bn()
    else:
        retinanet.freeze_bn()

    print("Evaluating model...")
    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        results = csv_eval.evaluate(
            dataset_val,
            retinanet,
            iou_threshold=opts.iou_threshold,
            score_threshold=opts.score_threshold,
        )

    _print_results(results)
    return results
if __name__ == '__main__':
    # Script entry point: with args=None, argparse falls back to sys.argv.
    main(args=None)