-
Notifications
You must be signed in to change notification settings - Fork 158
Expand file tree
/
Copy pathextractor.py
More file actions
334 lines (265 loc) · 10.8 KB
/
extractor.py
File metadata and controls
334 lines (265 loc) · 10.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
# Copyright (C) 2019-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT
from glob import iglob
from typing import Callable, Dict, Iterable, List, Optional
import os
import os.path as osp
from attr import attrib, attrs
import attr
import numpy as np
from datumaro.components.annotation import AnnotationType, Categories
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.errors import DatasetNotFoundError
from datumaro.util import is_method_redefined
from datumaro.util.attrs_util import default_if_none, not_empty
from datumaro.util.image import Image
# Re-export some names from .annotation for backwards compatibility.
import datumaro.components.annotation # isort:skip
for _name in [
    'Annotation', 'AnnotationType', 'Bbox', 'Caption', 'Categories',
    'CompiledMask', 'Cuboid3d', 'Label', 'LabelCategories', 'Mask',
    'MaskCategories', 'Points', 'PointsCategories', 'Polygon', 'RleMask',
]:
    # Bind each annotation class in this module's namespace so that
    # old client code doing "from .extractor import Bbox" keeps working.
    globals()[_name] = getattr(datumaro.components.annotation, _name)

# The subset name used for items that do not declare one explicitly.
DEFAULT_SUBSET_NAME = 'default'
@attrs
class DatasetItem:
    """
    A single dataset entry: a string id within a subset, plus optional
    media (an image or a point cloud - mutually exclusive, enforced in
    __attrs_post_init__) and a list of annotations.
    """

    # Item id. Backslashes are normalized to forward slashes, so ids are
    # OS-independent, path-like strings. Must be non-empty.
    id = attrib(converter=lambda x: str(x).replace('\\', '/'),
        type=str, validator=not_empty)

    # Annotations attached to this item.
    annotations = attrib(factory=list, validator=default_if_none(list))

    # Subset name; empty/None values fall back to DEFAULT_SUBSET_NAME.
    subset = attrib(converter=lambda v: v or DEFAULT_SUBSET_NAME,
        type=str, default=None)

    # Currently unused
    path = attrib(factory=list, validator=default_if_none(list))

    # TODO: introduce "media" field with type info. Replace image and pcd.
    image = attrib(type=Optional[Image], default=None)
    # TODO: introduce pcd type like Image
    # Path to the point cloud file; backslashes normalized like in 'id'.
    point_cloud = attrib(converter=lambda x: \
        str(x).replace('\\', '/') if x else None,
        type=Optional[str], default=None)
    # Extra images accompanying the point cloud; only allowed together
    # with 'point_cloud' (checked in __attrs_post_init__).
    related_images = attrib(type=List[Image], default=None)

    def __attrs_post_init__(self):
        # An item can carry either 2D or 3D media, but not both.
        if (self.has_image and self.has_point_cloud):
            raise ValueError("Can't set both image and point cloud info")
        if self.related_images and not self.has_point_cloud:
            raise ValueError("Related images require point cloud")

    def _image_converter(image):
        # Accept a lazy loader callable or a numpy array (raw pixels),
        # a string (path), an Image instance, or None; everything is
        # normalized to Image or None.
        if callable(image) or isinstance(image, np.ndarray):
            image = Image(data=image)
        elif isinstance(image, str):
            image = Image(path=image)
        assert image is None or isinstance(image, Image), type(image)
        return image
    image.converter = _image_converter

    def _related_image_converter(images):
        # Apply the same normalization to every related image.
        return list(map(__class__._image_converter, images or []))
    related_images.converter = _related_image_converter

    @point_cloud.validator
    def _point_cloud_validator(self, attribute, pcd):
        assert pcd is None or isinstance(pcd, str), type(pcd)

    # Free-form item attributes (key-value pairs).
    attributes = attrib(factory=dict, validator=default_if_none(dict))

    @property
    def has_image(self):
        return self.image is not None

    @property
    def has_point_cloud(self):
        return self.point_cloud is not None

    def wrap(item, **kwargs):
        """Returns a copy of the item with the given fields replaced."""
        return attr.evolve(item, **kwargs)
# Mapping of annotation type -> categories descriptor for a dataset.
CategoriesInfo = Dict[AnnotationType, Categories]
class IExtractor:
    """
    The dataset reader interface: an iterable of DatasetItem with
    category info and per-subset access.
    """

    def __iter__(self) -> Iterable[DatasetItem]:
        raise NotImplementedError()

    def __len__(self) -> int:
        raise NotImplementedError()

    def __bool__(self): # avoid __len__ use for truth checking
        return True

    def subsets(self) -> Dict[str, 'IExtractor']:
        """Returns a mapping of subset name -> subset extractor."""
        raise NotImplementedError()

    def get_subset(self, name) -> 'IExtractor':
        raise NotImplementedError()

    def categories(self) -> CategoriesInfo:
        raise NotImplementedError()

    def get(self, id, subset=None) -> Optional[DatasetItem]:
        """Returns the matching item, or None if there is no such item."""
        raise NotImplementedError()
class ExtractorBase(IExtractor):
    """
    Common implementation of length and subset caching for extractors
    that only define __iter__.
    """

    def __init__(self, length=None, subsets=None):
        # Both caches may be primed by the caller; None means "unknown,
        # compute lazily with a full pass over the items".
        self._length = length
        self._subsets = subsets

    def _init_cache(self):
        # A single pass over the items fills both caches at once.
        subsets = set()
        length = -1
        for length, item in enumerate(self):
            subsets.add(item.subset)
        # After the loop 'length' holds the last index (or the initial -1
        # for an empty extractor), so +1 turns it into the item count.
        length += 1

        if self._length is None:
            self._length = length
        if self._subsets is None:
            self._subsets = subsets

    def __len__(self):
        if self._length is None:
            self._init_cache()
        return self._length

    def subsets(self) -> Dict[str, IExtractor]:
        if self._subsets is None:
            self._init_cache()
        # An empty subset name is reported as the default subset.
        return {name or DEFAULT_SUBSET_NAME: self.get_subset(name)
            for name in self._subsets}

    def get_subset(self, name):
        if self._subsets is None:
            self._init_cache()
        if name in self._subsets:
            if len(self._subsets) == 1:
                # The only subset is the whole dataset - no filtering needed.
                return self

            subset = self.select(lambda item: item.subset == name)
            subset._subsets = [name]
            return subset
        else:
            raise KeyError("Unknown subset '%s', available subsets: %s" % \
                (name, set(self._subsets)))

    def transform(self, method, *args, **kwargs):
        # 'method' is a Transform-like callable taking this extractor first.
        return method(self, *args, **kwargs)

    def select(self, pred):
        """Returns a lazy view containing only items matching 'pred'."""
        class _DatasetFilter(ExtractorBase):
            # '_' instead of 'self': the closure captures the outer extractor.
            def __iter__(_):
                return filter(pred, iter(self))
            def categories(_):
                return self.categories()

        return _DatasetFilter()

    def categories(self):
        return {}

    def get(self, id, subset=None):
        # Linear search; subclasses may override with an indexed lookup.
        subset = subset or DEFAULT_SUBSET_NAME
        for item in self:
            if item.id == id and item.subset == subset:
                return item
        return None
class Extractor(ExtractorBase, CliPlugin):
    """
    A base class for user-defined and built-in extractors.
    Should be used in cases, where SourceExtractor is not enough,
    or its use makes problems with performance, implementation etc.
    """
class SourceExtractor(Extractor):
    """
    A base class for simple, single-subset extractors.
    Should be used by default for user-defined extractors.

    Subclasses are expected to fill self._items (and, optionally,
    self._categories) in their constructor.
    """

    def __init__(self, length=None, subset=None):
        # The single subset this extractor provides; never None.
        self._subset = subset or DEFAULT_SUBSET_NAME
        super().__init__(length=length, subsets=[self._subset])

        self._categories = {}
        self._items = []

    def categories(self):
        return self._categories

    def __iter__(self):
        yield from self._items

    def __len__(self):
        return len(self._items)

    def get(self, id, subset=None):
        # Normalize first: a missing subset means the default one.
        # Previously "assert subset == self._subset" failed for plain
        # get(id) calls, because self._subset is never None.
        subset = subset or DEFAULT_SUBSET_NAME
        assert subset == self._subset, '%s != %s' % (subset, self._subset)
        return super().get(id, subset)
class Importer(CliPlugin):
    """
    Discovers datasets of a specific format on the file system and
    produces source configurations for them.
    """

    @classmethod
    def detect(cls, path):
        # The format is detected if at least one source is found at 'path'.
        if not path or not osp.exists(path):
            return False

        return len(cls.find_sources_with_params(path)) != 0

    @classmethod
    def find_sources(cls, path) -> List[Dict]:
        raise NotImplementedError()

    @classmethod
    def find_sources_with_params(cls, path, **extra_params) -> List[Dict]:
        # By default, extra parameters do not affect the search.
        return cls.find_sources(path)

    def __call__(self, path, **extra_params):
        if not path or not osp.exists(path):
            raise DatasetNotFoundError(path)

        found_sources = self.find_sources_with_params(osp.normpath(path), **extra_params)
        if not found_sources:
            raise DatasetNotFoundError(path)

        sources = []
        for desc in found_sources:
            # Found options take priority over the caller's extra_params.
            params = dict(extra_params)
            params.update(desc.get('options', {}))
            # NOTE(review): this mutates the dicts returned by
            # find_sources_with_params() in place - confirm callers
            # do not reuse them.
            desc['options'] = params
            sources.append(desc)

        return sources

    @classmethod
    def _find_sources_recursive(cls, path: str, ext: Optional[str],
            extractor_name: str, filename: str = '*', dirname: str = '',
            file_filter: Optional[Callable[[str], bool]] = None,
            max_depth: int = 3):
        """
        Finds sources in the specified location, using the matching pattern
        to filter file names and directories.
        Supposed to be used, and to be the only call in subclasses.

        Parameters:
        - path - a directory or file path, where sources need to be found.
        - ext - file extension to match. To match directories,
          set this parameter to None or ''. Comparison is case-independent,
          a starting dot is not required.
        - extractor_name - the name of the associated Extractor type
        - filename - a glob pattern for file names
        - dirname - a glob pattern for filename prefixes
        - file_filter - a callable (abspath: str) -> bool, to filter paths found
        - max_depth - the maximum depth for recursive search.

        Returns: a list of source configurations
        (i.e. Extractor type names and c-tor parameters)
        """

        if ext:
            # Normalize to a lowercase ".ext" form.
            if not ext.startswith('.'):
                ext = '.' + ext
            ext = ext.lower()
        else:
            # The docstring allows ext=None, but the string operations
            # below (endswith, 'filename + ext') require a str - without
            # this, passing None raised TypeError.
            ext = ''

        if (path.lower().endswith(ext) and osp.isfile(path)) or \
                (not ext and dirname and osp.isdir(path) and \
                os.sep + osp.normpath(dirname.lower()) + os.sep in \
                    osp.abspath(path.lower()) + os.sep):
            # 'path' itself matches - it is the single source.
            sources = [{'url': path, 'format': extractor_name}]
        else:
            sources = []
            # Search one directory level at a time, stopping at the first
            # depth where anything is found.
            for d in range(max_depth + 1):
                sources.extend({'url': p, 'format': extractor_name} for p in
                    iglob(osp.join(path, *('*' * d), dirname, filename + ext))
                    if (callable(file_filter) and file_filter(p)) \
                        or (not callable(file_filter)))
                if sources:
                    break
        return sources
class Transform(ExtractorBase, CliPlugin):
    """
    A base class for dataset transformations that change dataset items
    or their annotations.
    """

    @staticmethod
    def wrap_item(item, **kwargs):
        # Returns a modified copy of the item (delegates to item.wrap()).
        return item.wrap(**kwargs)

    def __init__(self, extractor):
        super().__init__()

        # The parent extractor this transformation reads items from.
        self._extractor = extractor

    def categories(self):
        # By default, transformations do not change categories.
        return self._extractor.categories()

    def subsets(self):
        if self._subsets is None:
            # By default, the transformation keeps the parent's subsets.
            self._subsets = set(self._extractor.subsets())
        return super().subsets()

    def __len__(self):
        assert self._length in {None, 'parent'} or isinstance(self._length, int)
        # Note: 'and' binds tighter than 'or' below. The parent's length is
        # reused either when the length is unknown and __iter__ is not
        # overridden (so the item count cannot have changed), or when it is
        # explicitly requested with the 'parent' sentinel.
        if self._length is None and \
                not is_method_redefined('__iter__', Transform, self) \
                or self._length == 'parent':
            self._length = len(self._extractor)
        return super().__len__()
class ItemTransform(Transform):
    """An item-wise transformation: converts items one by one."""

    def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]:
        """
        Returns a modified copy of the input item.

        Avoid changing and returning the input item, because it can lead to
        unexpected problems. Use wrap_item() or item.wrap() to simplify copying.
        """
        raise NotImplementedError()

    def __iter__(self):
        # Items for which transform_item() returns None are dropped.
        for src_item in self._extractor:
            converted = self.transform_item(src_item)
            if converted is None:
                continue
            yield converted