Skip to content

Commit 0684b69

Browse files
committed
implement dataset for classification
1 parent 32d90c2 commit 0684b69

File tree

2 files changed

+154
-26
lines changed

2 files changed

+154
-26
lines changed

dataset/create_cv_splits.ipynb

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
"metadata": {
77
"collapsed": true,
88
"ExecuteTime": {
9-
"end_time": "2024-10-14T16:42:43.618315Z",
10-
"start_time": "2024-10-14T16:42:43.563966Z"
9+
"end_time": "2024-10-15T06:54:10.123002Z",
10+
"start_time": "2024-10-15T06:54:10.078224Z"
1111
}
1212
},
1313
"source": [
@@ -25,8 +25,8 @@
2525
{
2626
"metadata": {
2727
"ExecuteTime": {
28-
"end_time": "2024-10-14T16:42:43.633543Z",
29-
"start_time": "2024-10-14T16:42:43.619411Z"
28+
"end_time": "2024-10-15T06:54:10.138779Z",
29+
"start_time": "2024-10-15T06:54:10.123985Z"
3030
}
3131
},
3232
"cell_type": "code",
@@ -74,8 +74,8 @@
7474
{
7575
"metadata": {
7676
"ExecuteTime": {
77-
"end_time": "2024-10-14T16:42:43.644170Z",
78-
"start_time": "2024-10-14T16:42:43.634252Z"
77+
"end_time": "2024-10-15T06:54:10.146829Z",
78+
"start_time": "2024-10-15T06:54:10.139469Z"
7979
}
8080
},
8181
"cell_type": "code",
@@ -116,8 +116,8 @@
116116
{
117117
"metadata": {
118118
"ExecuteTime": {
119-
"end_time": "2024-10-14T16:42:43.653869Z",
120-
"start_time": "2024-10-14T16:42:43.645037Z"
119+
"end_time": "2024-10-15T06:54:10.168041Z",
120+
"start_time": "2024-10-15T06:54:10.147523Z"
121121
}
122122
},
123123
"cell_type": "code",
@@ -137,8 +137,8 @@
137137
{
138138
"metadata": {
139139
"ExecuteTime": {
140-
"end_time": "2024-10-14T16:42:44.314722Z",
141-
"start_time": "2024-10-14T16:42:43.654562Z"
140+
"end_time": "2024-10-15T06:54:10.837962Z",
141+
"start_time": "2024-10-15T06:54:10.169054Z"
142142
}
143143
},
144144
"cell_type": "code",
@@ -156,8 +156,8 @@
156156
{
157157
"metadata": {
158158
"ExecuteTime": {
159-
"end_time": "2024-10-14T16:42:44.320170Z",
160-
"start_time": "2024-10-14T16:42:44.315638Z"
159+
"end_time": "2024-10-15T06:54:10.842454Z",
160+
"start_time": "2024-10-15T06:54:10.838968Z"
161161
}
162162
},
163163
"cell_type": "code",
@@ -181,8 +181,8 @@
181181
{
182182
"metadata": {
183183
"ExecuteTime": {
184-
"end_time": "2024-10-14T16:42:44.353884Z",
185-
"start_time": "2024-10-14T16:42:44.320776Z"
184+
"end_time": "2024-10-15T06:54:10.876290Z",
185+
"start_time": "2024-10-15T06:54:10.843525Z"
186186
}
187187
},
188188
"cell_type": "code",
@@ -217,8 +217,8 @@
217217
{
218218
"metadata": {
219219
"ExecuteTime": {
220-
"end_time": "2024-10-14T16:42:44.358482Z",
221-
"start_time": "2024-10-14T16:42:44.354677Z"
220+
"end_time": "2024-10-15T06:54:10.881895Z",
221+
"start_time": "2024-10-15T06:54:10.877157Z"
222222
}
223223
},
224224
"cell_type": "code",
@@ -253,8 +253,8 @@
253253
{
254254
"metadata": {
255255
"ExecuteTime": {
256-
"end_time": "2024-10-14T16:42:44.549793Z",
257-
"start_time": "2024-10-14T16:42:44.359164Z"
256+
"end_time": "2024-10-15T06:54:11.074822Z",
257+
"start_time": "2024-10-15T06:54:10.883094Z"
258258
}
259259
},
260260
"cell_type": "code",
@@ -291,8 +291,8 @@
291291
{
292292
"metadata": {
293293
"ExecuteTime": {
294-
"end_time": "2024-10-14T16:42:44.560169Z",
295-
"start_time": "2024-10-14T16:42:44.550715Z"
294+
"end_time": "2024-10-15T06:54:11.090284Z",
295+
"start_time": "2024-10-15T06:54:11.076120Z"
296296
}
297297
},
298298
"cell_type": "code",
@@ -447,23 +447,22 @@
447447
{
448448
"metadata": {
449449
"ExecuteTime": {
450-
"end_time": "2024-10-14T16:42:44.633875Z",
451-
"start_time": "2024-10-14T16:42:44.560756Z"
450+
"end_time": "2024-10-15T06:54:11.156047Z",
451+
"start_time": "2024-10-15T06:54:11.091264Z"
452452
}
453453
},
454454
"cell_type": "code",
455455
"source": [
456-
"print('Known classes:', df_exploded['ao_classification'].unique())\n",
457-
"df_split.to_csv('/home/ron/Documents/AOClassification/data/dataset_cv_splits.csv', index=False)"
456+
"print('Known classes:', sorted(df_exploded['ao_classification'].unique()))\n",
457+
"df_split.to_csv('/home/ron/Documents/AOClassification/data/dataset_cv_splits.csv', index=False, mode='x')"
458458
],
459459
"id": "50fdb4a84921dec8",
460460
"outputs": [
461461
{
462462
"name": "stdout",
463463
"output_type": "stream",
464464
"text": [
465-
"Known classes: ['23r-M/2.1' '23-M/3.1' 'none' '23-M/2.1' '23u-E/7' '23r-E/2.1'\n",
466-
" '23r-M/3.1' '23u-M/2.1']\n"
465+
"Known classes: ['23-M/2.1', '23-M/3.1', '23r-E/2.1', '23r-M/2.1', '23r-M/3.1', '23u-E/7', '23u-M/2.1', 'none']\n"
467466
]
468467
}
469468
],

dataset/grazpedwri_dataset.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import logging
2+
from pathlib import Path
3+
from typing import Any
4+
5+
import pandas as pd
6+
import torch
7+
from PIL import Image
8+
from kornia.enhance import Normalize
9+
from pytorch_lightning import LightningDataModule
10+
from torch.utils.data import Dataset
11+
from torchvision.transforms.functional import to_tensor
12+
from tqdm import tqdm
13+
14+
15+
class GrazPedWriDataset(Dataset):
16+
# calculated over training split
17+
IMG_MEAN = 0.3505533917353781
18+
IMG_STD = 0.22763733675869177
19+
20+
RESCALE_HW = (384, 224)
21+
22+
CLASS_LABEL = ['23-M/2.1', '23-M/3.1', '23r-E/2.1', '23r-M/2.1', '23r-M/3.1', '23u-E/7', '23u-M/2.1', 'none']
23+
CLASS_IDX = {k: v for v, k in enumerate(CLASS_LABEL)}
24+
N_CLASSES = len(CLASS_LABEL)
25+
26+
def __init__(self, mode: str, fold: int = 0, number_training_samples: int | str = 'all'):
27+
super().__init__()
28+
# load data meta and other information
29+
self.df_meta = pd.read_csv('data/dataset_cv_splits.csv', index_col='filestem')
30+
# init ground truth parser considering the data split
31+
if mode == 'train':
32+
self.df_meta = self.df_meta[self.df_meta['fold'] != fold]
33+
elif mode == 'val':
34+
self.df_meta = self.df_meta[self.df_meta['fold'] == fold]
35+
else:
36+
raise ValueError(f'Unknown mode: {mode}')
37+
self.available_file_names = self.df_meta.index.tolist()
38+
39+
# get subset of training samples
40+
if mode == 'train' and number_training_samples != 'all':
41+
raise NotImplementedError('number_training_samples is not implemented for GrazPedWriDataset')
42+
elif mode != 'train' and number_training_samples != 'all':
43+
logging.warning(f'number_training_samples is not used for mode {mode}')
44+
45+
# load img into memory
46+
img_path = Path('data/img_only_front_all_left')
47+
self.data = dict()
48+
for file_name in tqdm(self.available_file_names, unit='img', desc=f'Loading data for {mode}'):
49+
# image
50+
img = Image.open(img_path.joinpath(file_name).with_suffix('.png')).convert('L')
51+
img = img.resize(self.RESCALE_HW[::-1], Image.BILINEAR)
52+
img = to_tensor(img)
53+
54+
# classification ground truth
55+
class_label: str = self.df_meta.loc[file_name, 'ao_classification']
56+
class_label: list[str] = class_label.split(';')
57+
y = torch.zeros(self.N_CLASSES)
58+
for c in class_label:
59+
c = c.strip()
60+
if c not in self.CLASS_IDX:
61+
continue
62+
else:
63+
y[self.CLASS_IDX[c]] = 1
64+
assert y.sum() > 0, f'No valid class found for {file_name} with {class_label}'
65+
66+
self.data[file_name] = {
67+
'file_name': file_name,
68+
'image': img,
69+
'y': y
70+
71+
}
72+
break
73+
74+
def __len__(self):
75+
return len(self.available_file_names)
76+
77+
def __getitem__(self, index):
78+
"""
79+
get item by index
80+
:param index: index of item
81+
:return: dict with keys ['image', 'mask', 'file_name']
82+
"""
83+
file_name = self.available_file_names[index]
84+
data_dict = self.data[file_name]
85+
86+
return data_dict
87+
88+
89+
class GrazPedWriDataModule(LightningDataModule):
    """Lightning data module wrapping GrazPedWriDataset for one cross-validation fold.

    There is no separate test split: the test loader serves the validation fold.
    """

    def __init__(self, fold: int = 0, batch_size: int = 32, number_training_samples: int | str = 'all'):
        """
        :param fold: index of the held-out cross-validation fold
        :param batch_size: batch size used by all data loaders
        :param number_training_samples: forwarded to the training dataset
        """
        super().__init__()
        self.n_train = number_training_samples
        self.fold = fold
        self.dl_kwargs = {'batch_size': batch_size, 'num_workers': 4, 'pin_memory': torch.cuda.is_available()}
        self.normalize = Normalize(mean=GrazPedWriDataset.IMG_MEAN, std=GrazPedWriDataset.IMG_STD)

    def setup(self, stage: str | None = None):
        """
        create the datasets required for the given stage
        :param stage: 'fit', 'validate', 'test' or None (None prepares everything)
        """
        if stage in ('fit', None):
            self.train_dataset = GrazPedWriDataset('train', self.fold, self.n_train)
        # 'validate' is its own Lightning stage (trainer.validate) and needs the val set too
        if stage in ('fit', 'validate', None):
            self.val_dataset = GrazPedWriDataset('val', self.fold)
        if stage in ('test', None):
            self.test_dataset = GrazPedWriDataset('val', self.fold)

    def train_dataloader(self):
        """Shuffled training loader; the last incomplete batch is dropped."""
        return torch.utils.data.DataLoader(self.train_dataset, shuffle=True, drop_last=True, **self.dl_kwargs)

    def val_dataloader(self):
        """Loader over the validation fold."""
        return torch.utils.data.DataLoader(self.val_dataset, **self.dl_kwargs)

    def test_dataloader(self):
        """Loader over the test dataset (built from the validation fold)."""
        return torch.utils.data.DataLoader(self.test_dataset, **self.dl_kwargs)

    def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        """Normalize ``batch['image']`` with the dataset mean/std and return the batch."""
        batch['image'] = self.normalize(batch['image'])
        return batch
116+
117+
118+
if __name__ == '__main__':
    # quick visual sanity check: load the validation fold and show its first sample
    import matplotlib.pyplot as plt

    dataset = GrazPedWriDataset('val', fold=0)
    data = dataset[0]
    print(data['image'].shape)
    print(data['y'])
    plt.figure(data['file_name'])
    plt.imshow(data['image'].squeeze().numpy(), cmap='gray')
    # y is multi-hot, so list every active class instead of only the argmax
    active_labels = [label for label, flag in zip(dataset.CLASS_LABEL, data['y'].tolist()) if flag]
    plt.title(', '.join(active_labels))
    plt.show()

0 commit comments

Comments
 (0)