Skip to content

Commit 4475859

Browse files
committed
add plotting of ROCs
1 parent 0a66df4 commit 4475859

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

evaluation/quantitative.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from dataset.grazpedwri_dataset import GrazPedWriDataset
88

9-
mode = ['end2end', 'lin_eval'][1]
9+
mode = ['end2end', 'lin_eval'][0]
1010

1111
metrics_kwargs = {'num_labels': GrazPedWriDataset.N_CLASSES, 'average': None}
1212
metrics = MetricCollection({

visualisation/plot_aurocs.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from pathlib import Path
2+
3+
import torch
4+
from matplotlib import pyplot as plt
5+
from torchmetrics import classification
6+
7+
from dataset.grazpedwri_dataset import GrazPedWriDataset
8+
9+
models2plot = ['image', 'image_frac_loc', 'image_frac_loc_bin_seg_clip']
10+
11+
pred_dir = Path('evaluation/predictions')
12+
gt = torch.load(pred_dir / 'ground_truth.pt')
13+
14+
path_dict = {k.stem.rsplit('_', 1)[0]: k for k in pred_dir.iterdir() if not k.is_dir()}
15+
pred_dict = dict()
16+
for experiment in models2plot:
17+
pred = torch.load(path_dict[experiment])
18+
y = []
19+
y_hat = []
20+
for file_stem in gt.keys():
21+
y.append(gt[file_stem])
22+
y_hat.append(pred[file_stem])
23+
y = torch.stack(y).int()
24+
y_hat = torch.stack(y_hat)
25+
pred_dict[experiment] = y_hat
26+
27+
28+
roc = classification.MultilabelROC(num_labels=len(models2plot))
29+
for c in range(GrazPedWriDataset.N_CLASSES):
30+
cat_y_hat = torch.stack([pred_dict[e][:, c] for e in models2plot], dim=1)
31+
32+
roc.update(cat_y_hat, y[:, c].unsqueeze(1).expand(-1, len(models2plot)))
33+
fig, axs = roc.plot(score=True, labels=models2plot)
34+
axs.set_title('')
35+
plt.legend(fontsize='large')
36+
37+
fig.savefig(f'/home/ron/Documents/Konferenzen/BVM 2025/ROCs/roc_{GrazPedWriDataset.CLASS_LABELS[c].replace('/', '_')}.pdf',
38+
bbox_inches='tight', pad_inches=0)
39+
40+
roc.reset()
41+
plt.show()
42+
43+

0 commit comments

Comments
 (0)