Skip to content

Commit 54107ed

Browse files
author
Maxim Zhiltsov
authored
Extend Dataset class, allow Extractor-based datasets (#71)
* Add auto dataset format detection to Environment
* Extend the Dataset class:
  - add save and load
  - add export and import_from
  - add detect
  - add run_model
* Allow extractor import in projects
* Allow extractor imports in CLI/project import
* Make dataset operations form a closed set
* Make dataset transforms eager (and all other operations too)
* Update convert command implementation
* Move default format declaration
1 parent e7759e8 commit 54107ed

File tree

10 files changed

+334
-90
lines changed

10 files changed

+334
-90
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
- `VGGFace2` dataset format (<https://github.com/openvinotoolkit/datumaro/pull/69>)
1515

1616
### Changed
17-
-
17+
- `Dataset` class extended with new operations: `save`, `load`, `export`, `import_from`, `detect`, `run_model` (<https://github.com/openvinotoolkit/datumaro/pull/71>)
18+
- `Dataset` operations return `Dataset` instances, allowing operations to be chained (<https://github.com/openvinotoolkit/datumaro/pull/71>)
19+
- Allowed importing formats that are defined only by an `Extractor` (in `Project.import_from`, `dataset.import_from` and CLI/`project import`)
1820

1921
### Deprecated
2022
-

datumaro/cli/commands/convert.py

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os.path as osp
1010

1111
from datumaro.components.project import Environment
12+
from datumaro.components.dataset import Dataset
1213

1314
from ..contexts.project import FilterModes
1415
from ..util import CliException, MultilineFormatter, make_file_name
@@ -68,46 +69,24 @@ def convert_command(args):
6869
raise CliException("Converter for format '%s' is not found" % \
6970
args.output_format)
7071
extra_args = converter.from_cmdline(args.extra_args)
71-
def converter_proxy(extractor, save_dir):
72-
return converter.convert(extractor, save_dir, **extra_args)
7372

7473
filter_args = FilterModes.make_filter_args(args.filter_mode)
7574

75+
fmt = args.input_format
7676
if not args.input_format:
77-
matches = []
78-
for format_name in env.importers.items:
79-
log.debug("Checking '%s' format...", format_name)
80-
importer = env.make_importer(format_name)
81-
try:
82-
match = importer.detect(args.source)
83-
if match:
84-
log.debug("format matched")
85-
matches.append((format_name, importer))
86-
except NotImplementedError:
87-
log.debug("Format '%s' does not support auto detection.",
88-
format_name)
89-
77+
matches = env.detect_dataset(args.source)
9078
if len(matches) == 0:
9179
log.error("Failed to detect dataset format. "
9280
"Try to specify format with '-if/--input-format' parameter.")
9381
return 1
9482
elif len(matches) != 1:
9583
log.error("Multiple formats match the dataset: %s. "
9684
"Try to specify format with '-if/--input-format' parameter.",
97-
', '.join(m[0] for m in matches))
85+
', '.join(matches))
9886
return 2
9987

100-
format_name, importer = matches[0]
101-
args.input_format = format_name
88+
fmt = matches[0]
10289
log.info("Source dataset format detected as '%s'", args.input_format)
103-
else:
104-
try:
105-
importer = env.make_importer(args.input_format)
106-
if hasattr(importer, 'from_cmdline'):
107-
extra_args = importer.from_cmdline()
108-
except KeyError:
109-
raise CliException("Importer for format '%s' is not found" % \
110-
args.input_format)
11190

11291
source = osp.abspath(args.source)
11392

@@ -121,15 +100,12 @@ def converter_proxy(extractor, save_dir):
121100
(osp.basename(source), make_file_name(args.output_format)))
122101
dst_dir = osp.abspath(dst_dir)
123102

124-
project = importer(source)
125-
dataset = project.make_dataset()
103+
dataset = Dataset.import_from(source, fmt)
126104

127105
log.info("Exporting the dataset")
128-
dataset.export_project(
129-
save_dir=dst_dir,
130-
converter=converter_proxy,
131-
filter_expr=args.filter,
132-
**filter_args)
106+
if args.filter:
107+
dataset = dataset.filter(args.filter, **filter_args)
108+
dataset.export(args.output_format, save_dir=dst_dir, **extra_args)
133109

134110
log.info("Dataset exported to '%s' as '%s'" % \
135111
(dst_dir, args.output_format))

datumaro/cli/contexts/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import os.path as osp
1010
import re
1111

12-
from datumaro.components.config import DEFAULT_FORMAT
12+
from datumaro.components.dataset import DEFAULT_FORMAT
1313
from datumaro.components.project import Environment
1414

1515
from ...util import CliException, MultilineFormatter, add_subparser

datumaro/cli/contexts/project/__init__.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -172,50 +172,43 @@ def import_command(args):
172172
log.info("Importing project from '%s'" % args.source)
173173

174174
extra_args = {}
175+
fmt = args.format
175176
if not args.format:
176177
if args.extra_args:
177178
raise CliException("Extra args can not be used without format")
178179

179180
log.info("Trying to detect dataset format...")
180181

181-
matches = []
182-
for format_name in env.importers.items:
183-
log.debug("Checking '%s' format...", format_name)
184-
importer = env.make_importer(format_name)
185-
try:
186-
match = importer.detect(args.source)
187-
if match:
188-
log.debug("format matched")
189-
matches.append((format_name, importer))
190-
except NotImplementedError:
191-
log.debug("Format '%s' does not support auto detection.",
192-
format_name)
193-
182+
matches = env.detect_dataset(args.source)
194183
if len(matches) == 0:
195184
log.error("Failed to detect dataset format automatically. "
196185
"Try to specify format with '-f/--format' parameter.")
197186
return 1
198187
elif len(matches) != 1:
199188
log.error("Multiple formats match the dataset: %s. "
200189
"Try to specify format with '-f/--format' parameter.",
201-
', '.join(m[0] for m in matches))
190+
', '.join(matches))
202191
return 2
203192

204-
format_name, importer = matches[0]
205-
args.format = format_name
206-
else:
207-
try:
208-
importer = env.make_importer(args.format)
209-
if hasattr(importer, 'from_cmdline'):
210-
extra_args = importer.from_cmdline(args.extra_args)
211-
except KeyError:
212-
raise CliException("Importer for format '%s' is not found" % \
213-
args.format)
214-
215-
log.info("Importing project as '%s'" % args.format)
216-
217-
source = osp.abspath(args.source)
218-
project = importer(source, **extra_args)
193+
fmt = matches[0]
194+
elif args.extra_args:
195+
if fmt in env.importers:
196+
arg_parser = env.importers[fmt]
197+
elif fmt in env.extractors:
198+
arg_parser = env.extractors[fmt]
199+
else:
200+
raise CliException("Unknown format '%s'. A format can be added"
201+
"by providing an Extractor and Importer plugins" % fmt)
202+
203+
if hasattr(arg_parser, 'from_cmdline'):
204+
extra_args = arg_parser.from_cmdline(args.extra_args)
205+
else:
206+
raise CliException("Format '%s' does not accept "
207+
"extra parameters" % fmt)
208+
209+
log.info("Importing project as '%s'" % fmt)
210+
211+
project = Project.import_from(osp.abspath(args.source), fmt, **extra_args)
219212
project.config.project_name = project_name
220213
project.config.project_dir = project_dir
221214

datumaro/components/config.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,3 @@ def set(self, key, value):
232232
return super().set(key, value)
233233
else:
234234
return super().set(key, value)
235-
236-
237-
DEFAULT_FORMAT = 'datumaro'

datumaro/components/dataset.py

Lines changed: 110 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
1-
# Copyright (C) 2020 Intel Corporation
1+
# Copyright (C) 2020-2021 Intel Corporation
22
#
33
# SPDX-License-Identifier: MIT
44

55
from collections import OrderedDict, defaultdict
66
from typing import Iterable, Union, Dict, List
7+
import logging as log
8+
import os
9+
import os.path as osp
10+
import shutil
711

812
from datumaro.components.extractor import (Extractor, LabelCategories,
913
AnnotationType, DatasetItem, DEFAULT_SUBSET_NAME)
1014
from datumaro.components.dataset_filter import \
1115
XPathDatasetFilter, XPathAnnotationsFilter
16+
from datumaro.components.environment import Environment
17+
from datumaro.util import error_rollback
18+
from datumaro.util.log_utils import logging_disabled
1219

1320

21+
DEFAULT_FORMAT = 'datumaro'
22+
1423
class Dataset(Extractor):
1524
class Subset(Extractor):
1625
def __init__(self, parent):
@@ -28,7 +37,8 @@ def categories(self):
2837

2938
@classmethod
3039
def from_iterable(cls, iterable: Iterable[DatasetItem],
31-
categories: Union[Dict, List[str]] = None):
40+
categories: Union[Dict, List[str]] = None,
41+
env: Environment = None):
3242
if isinstance(categories, list):
3343
categories = { AnnotationType.label:
3444
LabelCategories.from_iterable(categories)
@@ -44,12 +54,12 @@ def __iter__(self):
4454
def categories(self):
4555
return categories
4656

47-
return cls.from_extractors(_extractor())
57+
return cls.from_extractors(_extractor(), env=env)
4858

4959
@classmethod
50-
def from_extractors(cls, *sources):
60+
def from_extractors(cls, *sources, env=None):
5161
categories = cls._merge_categories(s.categories() for s in sources)
52-
dataset = Dataset(categories=categories)
62+
dataset = Dataset(categories=categories, env=env)
5363

5464
# merge items
5565
subsets = defaultdict(lambda: cls.Subset(dataset))
@@ -67,9 +77,12 @@ def from_extractors(cls, *sources):
6777
dataset._subsets = dict(subsets)
6878
return dataset
6979

70-
def __init__(self, categories=None):
80+
def __init__(self, categories=None, env=None):
7181
super().__init__()
7282

83+
assert env is None or isinstance(env, Environment), env
84+
self._env = env
85+
7386
self._subsets = {}
7487

7588
if not categories:
@@ -183,4 +196,94 @@ def _merge_anno(a, b):
183196
def _merge_categories(sources):
184197
# TODO: implement properly with merging and annotations remapping
185198
from .operations import merge_categories
186-
return merge_categories(sources)
199+
return merge_categories(sources)
200+
201+
@error_rollback('on_error', implicit=True)
def export(self, converter, save_dir, **kwargs):
    """
    Writes this dataset to a directory using the given converter.

    Parameters:
        converter: either a callable, invoked as
            converter(dataset, save_dir=..., **kwargs), or a string,
            which is resolved with self.env.make_converter()
        save_dir: the output directory; it is converted to an absolute
            path and created if it does not exist
        **kwargs: passed to the converter unchanged
    """
    if isinstance(converter, str):
        converter = self.env.make_converter(converter)

    save_dir = osp.abspath(save_dir)
    if not osp.exists(save_dir):
        # The directory does not exist yet, so its removal is registered
        # before it is created.
        # NOTE(review): 'on_error' is not defined in this block; presumably
        # it is supplied by the error_rollback decorator (implicit=True)
        # and runs the registered actions when an exception escapes -
        # confirm against datumaro.util.error_rollback
        on_error.do(shutil.rmtree, save_dir, ignore_errors=True)
    os.makedirs(save_dir, exist_ok=True)
    converter(self, save_dir=save_dir, **kwargs)
211+
212+
def transform(self, method, *args, **kwargs):
    """
    Applies a transform to this dataset and returns the result as
    a new Dataset.

    Parameters:
        method: the transform to apply; a string is resolved with
            self.env.make_transform(), anything else is passed as is
            to the base class transform()
        *args, **kwargs: forwarded to the base class transform()
    """
    transform_impl = method
    if isinstance(transform_impl, str):
        transform_impl = self.env.make_transform(transform_impl)

    transformed = super().transform(transform_impl, *args, **kwargs)

    # The result is rebuilt with from_extractors() and carries over
    # the environment stored in this dataset
    return Dataset.from_extractors(transformed, env=self._env)
218+
219+
def run_model(self, model, batch_size=1):
    """
    Runs a model over this dataset through transform().

    Parameters:
        model: a Launcher instance (applied via ModelTransform) or
            a ModelTransform instance
        batch_size: forwarded to the transform

    Raises:
        TypeError: if model is neither a Launcher nor a ModelTransform
    """
    from datumaro.components.launcher import Launcher, ModelTransform

    if isinstance(model, Launcher):
        return self.transform(ModelTransform, launcher=model,
            batch_size=batch_size)

    if isinstance(model, ModelTransform):
        return self.transform(model, batch_size=batch_size)

    raise TypeError('Unexpected model argument type: %s' % type(model))
228+
229+
@property
def env(self):
    """
    The environment used by this dataset. If none is stored (the stored
    value is falsy), a default Environment is created on first access
    and kept for later calls.
    """
    current = self._env
    if not current:
        current = Environment()
        self._env = current
    return current
234+
235+
def save(self, save_dir, **kwargs):
    """
    Exports this dataset to save_dir in the default format
    (DEFAULT_FORMAT). Extra keyword arguments are forwarded to export().
    """
    export_options = dict(kwargs, save_dir=save_dir)
    self.export(DEFAULT_FORMAT, **export_options)
237+
238+
@classmethod
def load(cls, path, **kwargs):
    """
    Imports a dataset stored at path in the default format
    (DEFAULT_FORMAT). Extra keyword arguments are forwarded
    to import_from().
    """
    dataset = cls.import_from(path, format=DEFAULT_FORMAT, **kwargs)
    return dataset
241+
242+
@classmethod
def import_from(cls, path, format=None, env=None, **kwargs): #pylint: disable=redefined-builtin
    """
    Creates a dataset from the data located at path.

    Parameters:
        path: the location of the dataset
        format: the name of the dataset format; when empty, the format
            is detected with detect()
        env: the Environment used to look up importers and extractors;
            a default Environment is created when omitted. The same
            environment is attached to the returned dataset.
        **kwargs: passed to the importer, or, for a format known only
            as an extractor, used as the extractor options

    Raises:
        Exception: if the format is neither a known importer nor
            a known extractor (detect() can also raise when the format
            is not specified)
    """
    from datumaro.components.config_model import Source

    if env is None:
        env = Environment()

    # TODO: remove importers, put this logic into extractors
    if not format:
        format = cls.detect(path, env)

    if format in env.importers:
        importer = env.make_importer(format)
        with logging_disabled(log.INFO):
            project = importer(path, **kwargs)
        detected_sources = list(project.config.sources.values())
    elif format in env.extractors:
        # No importer for this format: describe the path itself
        # as the only source
        detected_sources = [{
            'url': path, 'format': format, 'options': kwargs
        }]
    else:
        raise Exception("Unknown source format '%s'. To make it "
            "available, add the corresponding Extractor implementation "
            "to the environment" % format)

    extractors = []
    for src_conf in detected_sources:
        if not isinstance(src_conf, Source):
            src_conf = Source(src_conf)
        extractors.append(env.make_extractor(
            src_conf.format, src_conf.url, **src_conf.options
        ))

    # Keep the environment that resolved the format, so that later
    # operations on the dataset (export, transform) see the same plugins
    # instead of a freshly created default Environment
    return cls.from_extractors(*extractors, env=env)
275+
276+
@staticmethod
def detect(path, env=None):
    """
    Determines the format of the dataset located at path.

    Parameters:
        path: the location of the dataset
        env: the Environment queried with detect_dataset(); a default
            Environment is created when omitted

    Returns the name of the single matching format.

    Raises:
        Exception: if no format matches or more than one format matches
    """
    environment = Environment() if env is None else env

    candidates = environment.detect_dataset(path)
    if not candidates:
        raise Exception("Failed to detect dataset format automatically: "
            "no matching formats found")

    if len(candidates) > 1:
        raise Exception("Failed to detect dataset format automatically:"
            " data matches more than one format: %s" % \
            ', '.join(candidates))

    return candidates[0]

datumaro/components/environment.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,22 @@ def register_model(self, name, model):
289289

290290
def unregister_model(self, name):
291291
self.models.unregister(name)
292+
293+
def is_format_known(self, name):
    """
    Tells whether name is registered in this environment as an importer
    or as an extractor.
    """
    if name in self.importers:
        return True
    return name in self.extractors
295+
296+
def detect_dataset(self, path):
    """
    Checks path against every registered importer and returns the list
    of names of the formats whose importer reports a match.

    Importers raising NotImplementedError from detect() are skipped.
    The returned list is empty if nothing matches.
    """
    detected = []

    for fmt, importer in self.importers.items.items():
        log.debug("Checking '%s' format...", fmt)
        try:
            is_match = importer.detect(path)
        except NotImplementedError:
            log.debug("Format '%s' does not support auto detection.",
                fmt)
            continue

        if is_match:
            log.debug("format matched")
            detected.append(fmt)

    return detected

0 commit comments

Comments
 (0)