|
| 1 | +from pathlib import Path |
| 2 | + |
| 3 | +import pandas as pd |
1 | 4 | import torch |
2 | 5 | from torchmetrics import classification, MetricCollection |
| 6 | + |
3 | 7 | from dataset.grazpedwri_dataset import GrazPedWriDataset |
4 | | -import pandas as pd |
5 | | -from pathlib import Path |
6 | | -from evaluation.best_shot_accuracy import BestShotAccuracy |
7 | 8 |
|
8 | | -mode = ['end2end', 'lin_eval'][0] |
| 9 | +mode = ['end2end', 'lin_eval'][1] |
9 | 10 |
|
10 | 11 | metrics_kwargs = {'num_labels': GrazPedWriDataset.N_CLASSES, 'average': None} |
11 | 12 | metrics = MetricCollection({ |
12 | | - "Acc": classification.MultilabelAccuracy(**metrics_kwargs), |
| 13 | + "Accuracy": classification.MultilabelAccuracy(**metrics_kwargs), |
13 | 14 | "F1": classification.MultilabelF1Score(**metrics_kwargs), |
14 | 15 | "Precision": classification.MultilabelPrecision(**metrics_kwargs), |
15 | 16 | "Recall": classification.MultilabelRecall(**metrics_kwargs), |
16 | | - "AUROC": classification.MultilabelAUROC(**metrics_kwargs), |
17 | | - "BestShotAcc": BestShotAccuracy() |
| 17 | + "AUROC": classification.MultilabelAUROC(**metrics_kwargs) |
18 | 18 | }) |
19 | 19 | pred_dir = Path('evaluation/predictions') |
20 | 20 | gt = torch.load(pred_dir / 'ground_truth.pt') |
21 | 21 |
|
22 | | -mean_df = pd.DataFrame(columns=['Experiment', 'Acc', 'BestShotAcc', 'F1', 'Precision', 'Recall', 'AUROC']) |
23 | | -experiment_df = pd.DataFrame(columns=['Experiment', 'Acc', 'F1', 'Precision', 'Recall', 'AUROC', 'AO_Class']) |
| 22 | +mean_df = pd.DataFrame(columns=['Experiment', 'Accuracy', 'F1', 'Precision', 'Recall', 'AUROC']) |
| 23 | +experiment_df = pd.DataFrame(columns=['Experiment', 'Accuracy', 'F1', 'Precision', 'Recall', 'AUROC', 'AO_Class']) |
24 | 24 | for experiment in pred_dir.iterdir(): |
25 | 25 | is_line_eval = experiment.stem.startswith('LE') |
26 | 26 | match_mode = (mode == 'lin_eval' and is_line_eval) or (mode == 'end2end' and not is_line_eval) |
27 | | - if experiment.stem == 'ground_truth' or experiment.is_dir() or not match_mode: |
| 27 | + contains_mult_seg = 'mult_seg' in experiment.stem |
| 28 | + if experiment.stem == 'ground_truth' or experiment.is_dir() or not match_mode or contains_mult_seg: |
28 | 29 | continue |
29 | 30 |
|
30 | 31 | pred = torch.load(experiment) |
|
39 | 40 | performance = metrics(y_hat, y) |
40 | 41 | mean_df = pd.concat([mean_df, pd.DataFrame({ |
41 | 42 | 'Experiment': experiment.stem.rsplit('_', 1)[0], |
42 | | - 'Acc': performance['Acc'].mean().item(), |
43 | | - "BestShotAcc": performance['BestShotAcc'].item(), |
| 43 | + 'Accuracy': performance['Accuracy'].mean().item(), |
44 | 44 | 'F1': performance['F1'].mean().item(), |
45 | 45 | 'Precision': performance['Precision'].mean().item(), |
46 | 46 | 'Recall': performance['Recall'].mean().item(), |
|
49 | 49 |
|
50 | 50 | experiment_df = pd.concat([experiment_df, pd.DataFrame({ |
51 | 51 | 'Experiment': experiment.stem.rsplit('_', 1)[0], |
52 | | - 'Acc': performance['Acc'].tolist(), |
| 52 | + 'Accuracy': performance['Accuracy'].tolist(), |
53 | 53 | 'F1': performance['F1'].tolist(), |
54 | 54 | 'Precision': performance['Precision'].tolist(), |
55 | 55 | 'Recall': performance['Recall'].tolist(), |
|
0 commit comments