Skip to content

Commit c473ba9

Browse files
author
Maxim Zhiltsov
authored
Move dataset tests to a separate file (#74)
1 parent 54107ed commit c473ba9

File tree

2 files changed

+300
-288
lines changed

2 files changed

+300
-288
lines changed

tests/test_dataset.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
import numpy as np
2+
3+
from unittest import TestCase
4+
5+
from datumaro.components.project import Environment
6+
from datumaro.components.extractor import (Extractor, DatasetItem,
7+
Label, Mask, Points, Polygon, PolyLine, Bbox, Caption,
8+
LabelCategories, AnnotationType, Transform
9+
)
10+
from datumaro.util.image import Image
11+
from datumaro.components.dataset_filter import \
12+
XPathDatasetFilter, XPathAnnotationsFilter, DatasetItemEncoder
13+
from datumaro.components.dataset import Dataset, DEFAULT_FORMAT
14+
from datumaro.util.test_utils import TestDir, compare_datasets
15+
16+
17+
class DatasetTest(TestCase):
    """Tests for building, saving, loading, exporting and transforming
    a ``Dataset``.
    """

    def test_create_from_extractors(self):
        """Compare a dataset built from two extractors with the expected one.

        The expected result keeps the 'train' item as is and gives the
        'val' item (present in both sources) the labels of both sources.
        """
        class SrcExtractor1(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=1, subset='train', annotations=[
                        Bbox(1, 2, 3, 4),
                        Label(4),
                    ]),
                    DatasetItem(id=1, subset='val', annotations=[
                        Label(4),
                    ]),
                ])

        class SrcExtractor2(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=1, subset='val', annotations=[
                        Label(5),
                    ]),
                ])

        class DstExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=1, subset='train', annotations=[
                        Bbox(1, 2, 3, 4),
                        Label(4),
                    ]),
                    DatasetItem(id=1, subset='val', annotations=[
                        Label(4),
                        Label(5),
                    ]),
                ])

        dataset = Dataset.from_extractors(SrcExtractor1(), SrcExtractor2())

        compare_datasets(self, DstExtractor(), dataset)

    def test_can_create_from_iterable(self):
        """Compare ``Dataset.from_iterable`` output with a hand-written
        extractor that yields the same items and label categories.
        """
        class TestExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=1, subset='train', annotations=[
                        Bbox(1, 2, 3, 4, label=2),
                        Label(4),
                    ]),
                    DatasetItem(id=1, subset='val', annotations=[
                        Label(3),
                    ]),
                ])

            def categories(self):
                return { AnnotationType.label: LabelCategories.from_iterable(
                    ['a', 'b', 'c', 'd', 'e'])
                }

        actual = Dataset.from_iterable([
            DatasetItem(id=1, subset='train', annotations=[
                Bbox(1, 2, 3, 4, label=2),
                Label(4),
            ]),
            DatasetItem(id=1, subset='val', annotations=[
                Label(3),
            ]),
        ], categories=['a', 'b', 'c', 'd', 'e'])

        compare_datasets(self, TestExtractor(), actual)

    def test_can_save_and_load(self):
        """Save a dataset to a temporary directory, load it back and
        compare the result with the source.
        """
        source_dataset = Dataset.from_iterable([
            DatasetItem(id=1, annotations=[ Label(2) ]),
        ], categories=['a', 'b', 'c'])

        with TestDir() as test_dir:
            source_dataset.save(test_dir)

            loaded_dataset = Dataset.load(test_dir)

            compare_datasets(self, source_dataset, loaded_dataset)

    def test_can_detect(self):
        """Save a dataset and check that ``Dataset.detect`` reports
        ``DEFAULT_FORMAT`` for the saved directory.
        """
        env = Environment()
        # Only the default format stays registered in this environment
        # (presumably to keep detection unambiguous - confirm if changed).
        env.importers.items = {DEFAULT_FORMAT: env.importers[DEFAULT_FORMAT]}
        env.extractors.items = {DEFAULT_FORMAT: env.extractors[DEFAULT_FORMAT]}

        dataset = Dataset.from_iterable([
            DatasetItem(id=1, annotations=[ Label(2) ]),
        ], categories=['a', 'b', 'c'])

        with TestDir() as test_dir:
            dataset.save(test_dir)

            detected_format = Dataset.detect(test_dir, env=env)

            self.assertEqual(DEFAULT_FORMAT, detected_format)

    def test_can_detect_and_import(self):
        """Save a dataset, import it with ``Dataset.import_from`` without
        naming a format, and compare the result with the source.
        """
        env = Environment()
        # Only the default format stays registered in this environment
        # (presumably to keep detection unambiguous - confirm if changed).
        env.importers.items = {DEFAULT_FORMAT: env.importers[DEFAULT_FORMAT]}
        env.extractors.items = {DEFAULT_FORMAT: env.extractors[DEFAULT_FORMAT]}

        source_dataset = Dataset.from_iterable([
            DatasetItem(id=1, annotations=[ Label(2) ]),
        ], categories=['a', 'b', 'c'])

        with TestDir() as test_dir:
            source_dataset.save(test_dir)

            imported_dataset = Dataset.import_from(test_dir, env=env)

            compare_datasets(self, source_dataset, imported_dataset)

    def test_can_export_by_string_format_name(self):
        """Export a dataset using a converter referenced by its string name.

        There is no assertion on the output: the test passes as long as
        ``export`` does not raise.
        """
        env = Environment()
        # Register the default converter under the custom name 'qq'.
        env.converters.items = {'qq': env.converters[DEFAULT_FORMAT]}

        dataset = Dataset.from_iterable([
            DatasetItem(id=1, annotations=[ Label(2) ]),
        ], categories=['a', 'b', 'c'], env=env)

        with TestDir() as test_dir:
            dataset.export('qq', save_dir=test_dir)

    def test_can_transform_by_string_name(self):
        """Apply a transform referenced by its string name and check that
        the result is a ``Dataset`` with the same environment and the
        expected item attributes.
        """
        expected = Dataset.from_iterable([
            DatasetItem(id=1, annotations=[ Label(2) ], attributes={'qq': 1}),
        ], categories=['a', 'b', 'c'])

        class TestTransform(Transform):
            def transform_item(self, item):
                return self.wrap_item(item, attributes={'qq': 1})

        env = Environment()
        # Register the transform under the name used in transform() below.
        env.transforms.items = {'qq': TestTransform}

        dataset = Dataset.from_iterable([
            DatasetItem(id=1, annotations=[ Label(2) ]),
        ], categories=['a', 'b', 'c'], env=env)

        actual = dataset.transform('qq')

        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)) which only says "False is not true".
        self.assertIsInstance(actual, Dataset)
        self.assertEqual(env, actual.env)
        compare_datasets(self, expected, actual)
162+
163+
164+
class DatasetItemTest(TestCase):
    """Constructor checks for ``DatasetItem``."""

    def test_ctor_requires_id(self):
        """Constructing an item without any arguments must raise."""
        with self.assertRaises(Exception):
            # pylint: disable=no-value-for-parameter
            DatasetItem()
            # pylint: enable=no-value-for-parameter

    @staticmethod
    def test_ctors_with_image():
        """Construct items with each accepted kind of ``image`` argument.

        The test only checks that none of the constructor calls raises.
        """
        image_values = (
            None,
            'path.jpg',
            np.array([1, 2, 3]),
            lambda f: np.array([1, 2, 3]),
            Image(data=np.array([1, 2, 3])),
        )

        for image_value in image_values:
            DatasetItem(id=0, image=image_value)
181+
182+
183+
class DatasetFilterTest(TestCase):
    """Tests for the XPath-based item and annotation filters."""

    @staticmethod
    def test_item_representations():
        """Encode an item carrying many annotation kinds and render it to
        a string.

        The test only checks that neither call raises.
        """
        annotations = [
            Label(0, attributes={'a1': 1, 'a2': '2'}, id=1, group=2),
            Caption('hello', id=1),
            Caption('world', group=5),
            Label(2, id=3, attributes={ 'x': 1, 'y': '2' }),
            Bbox(1, 2, 3, 4, label=4, id=4, attributes={ 'a': 1.0 }),
            Bbox(5, 6, 7, 8, id=5, group=5),
            Points([1, 2, 2, 0, 1, 1], label=0, id=5),
            Mask(id=5, image=np.ones((3, 2))),
            Mask(label=3, id=5, image=np.ones((2, 3))),
            PolyLine([1, 2, 3, 4, 5, 6, 7, 8], id=11),
            Polygon([1, 2, 3, 4, 5, 6, 7, 8]),
        ]
        item = DatasetItem(
            id=1,
            subset='subset',
            path=['a', 'b'],
            image=np.ones((5, 4, 3)),
            annotations=annotations,
        )

        DatasetItemEncoder.to_string(DatasetItemEncoder.encode(item))

    def test_item_filter_can_be_applied(self):
        """Filter four items (ids 0..3) with '/item[id > 1]' and expect
        two of them to remain.
        """
        class TestExtractor(Extractor):
            def __iter__(self):
                return iter([
                    DatasetItem(id=index, subset='train')
                    for index in range(4)
                ])

        filtered = XPathDatasetFilter(TestExtractor(), '/item[id > 1]')

        self.assertEqual(2, len(filtered))

    def test_annotations_filter_can_be_applied(self):
        """Keep only annotations with ``label_id = 0``.

        Items are expected to stay in the output even when no annotation
        is left on them.
        """
        class SrcExtractor(Extractor):
            def __iter__(self):
                yield DatasetItem(id=0)
                yield DatasetItem(id=1, annotations=[
                    Label(0),
                    Label(1),
                ])
                yield DatasetItem(id=2, annotations=[
                    Label(0),
                    Label(2),
                ])

        class DstExtractor(Extractor):
            def __iter__(self):
                yield DatasetItem(id=0)
                yield DatasetItem(id=1, annotations=[
                    Label(0),
                ])
                yield DatasetItem(id=2, annotations=[
                    Label(0),
                ])

        source = SrcExtractor()
        xpath = '/item/annotation[label_id = 0]'

        filtered = XPathAnnotationsFilter(source, xpath)

        self.assertListEqual(list(filtered), list(DstExtractor()))

    def test_annotations_filter_can_remove_empty_items(self):
        """Keep only annotations with ``label_id = 2`` and, with
        ``remove_empty=True``, expect items left without annotations to
        be dropped from the output.
        """
        class SrcExtractor(Extractor):
            def __iter__(self):
                yield DatasetItem(id=0)
                yield DatasetItem(id=1, annotations=[
                    Label(0),
                    Label(1),
                ])
                yield DatasetItem(id=2, annotations=[
                    Label(0),
                    Label(2),
                ])

        class DstExtractor(Extractor):
            def __iter__(self):
                yield DatasetItem(id=2, annotations=[
                    Label(2),
                ])

        source = SrcExtractor()
        xpath = '/item/annotation[label_id = 2]'

        filtered = XPathAnnotationsFilter(source, xpath, remove_empty=True)

        self.assertListEqual(list(filtered), list(DstExtractor()))

0 commit comments

Comments
 (0)