Skip to content

Commit f33efb9

Browse files
author
Maxim Zhiltsov
committed
t
1 parent f21f71d commit f33efb9

File tree

18 files changed

+2601
-801
lines changed

18 files changed

+2601
-801
lines changed

datumaro/components/config.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -238,8 +238,14 @@ def __init__(self, default=None):
238238
self.__dict__['_default'] = default
239239

240240
def set(self, key, value):
241-
if key not in self.keys(allow_fallback=False):
242-
value = self._default(value)
243-
return super().set(key, value)
244-
else:
245-
return super().set(key, value)
241+
if self._default is not None:
242+
schema_entry_instance = self._default(value)
243+
if not isinstance(value, type(schema_entry_instance)):
244+
if isinstance(value, dict) and \
245+
isinstance(schema_entry_instance, Config):
246+
schema_entry_instance.update(value)
247+
value = schema_entry_instance
248+
else:
249+
raise Exception("Can not set key '%s' - schema mismatch" % (key))
250+
251+
return super().set(key, value)

datumaro/components/config_model.py

Lines changed: 72 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,11 +7,25 @@
77
DictConfig as _DictConfig, \
88
SchemaBuilder as _SchemaBuilder
99

10+
from datumaro.util import find
11+
12+
13+
REMOTE_SCHEMA = _SchemaBuilder() \
14+
.add('url', str) \
15+
.add('type', str) \
16+
.add('options', dict) \
17+
.build()
18+
19+
class Remote(Config):
20+
def __init__(self, config=None):
21+
super().__init__(config, schema=REMOTE_SCHEMA)
22+
1023

1124
SOURCE_SCHEMA = _SchemaBuilder() \
1225
.add('url', str) \
1326
.add('format', str) \
1427
.add('options', dict) \
28+
.add('remote', str) \
1529
.build()
1630

1731
class Source(Config):
@@ -29,35 +43,87 @@ def __init__(self, config=None):
2943
super().__init__(config, schema=MODEL_SCHEMA)
3044

3145

46+
BUILDSTAGE_SCHEMA = _SchemaBuilder() \
47+
.add('name', str) \
48+
.add('type', str) \
49+
.add('kind', str) \
50+
.add('params', dict) \
51+
.build()
52+
53+
class BuildStage(Config):
54+
def __init__(self, config=None):
55+
super().__init__(config, schema=BUILDSTAGE_SCHEMA)
56+
57+
BUILDTARGET_SCHEMA = _SchemaBuilder() \
58+
.add('stages', list) \
59+
.add('parents', list) \
60+
.build()
61+
62+
class BuildTarget(Config):
63+
def __init__(self, config=None):
64+
super().__init__(config, schema=BUILDTARGET_SCHEMA)
65+
self.stages = [BuildStage(o) for o in self.stages]
66+
67+
@property
68+
def root(self):
69+
return self.stages[0]
70+
71+
@property
72+
def head(self):
73+
return self.stages[-1]
74+
75+
def find_stage(self, stage):
76+
if stage == 'root':
77+
return self.root
78+
elif stage == 'head':
79+
return self.head
80+
return find(self.stages, lambda x: x.name == stage or x == stage)
81+
82+
def get_stage(self, stage):
83+
res = self.find_stage(stage)
84+
if res is None:
85+
raise KeyError("Unknown stage '%s'" % stage)
86+
return res
87+
88+
3289
PROJECT_SCHEMA = _SchemaBuilder() \
3390
.add('project_name', str) \
3491
.add('format_version', int) \
3592
\
36-
.add('subsets', list) \
37-
.add('sources', lambda: _DictConfig(
38-
lambda v=None: Source(v))) \
39-
.add('models', lambda: _DictConfig(
40-
lambda v=None: Model(v))) \
93+
.add('default_repo', str) \
94+
.add('remotes', lambda: _DictConfig(lambda v=None: Remote(v))) \
95+
.add('sources', lambda: _DictConfig(lambda v=None: Source(v))) \
96+
.add('models', lambda: _DictConfig(lambda v=None: Model(v))) \
97+
.add('build_targets', lambda: _DictConfig(lambda v=None: BuildTarget(v))) \
4198
\
4299
.add('models_dir', str, internal=True) \
43100
.add('plugins_dir', str, internal=True) \
44101
.add('sources_dir', str, internal=True) \
45102
.add('dataset_dir', str, internal=True) \
103+
.add('dvc_aux_dir', str, internal=True) \
104+
.add('pipelines_dir', str, internal=True) \
105+
.add('build_dir', str, internal=True) \
46106
.add('project_filename', str, internal=True) \
47107
.add('project_dir', str, internal=True) \
48108
.add('env_dir', str, internal=True) \
109+
.add('detached', bool, internal=True) \
49110
.build()
50111

51112
PROJECT_DEFAULT_CONFIG = Config({
52113
'project_name': 'undefined',
53-
'format_version': 1,
114+
'format_version': 2,
54115

55116
'sources_dir': 'sources',
56117
'dataset_dir': 'dataset',
57118
'models_dir': 'models',
58119
'plugins_dir': 'plugins',
120+
'dvc_aux_dir': 'dvc_aux',
121+
'pipelines_dir': 'dvc_pipelines',
122+
'build_dir': 'build',
59123

124+
'default_repo': 'origin',
60125
'project_filename': 'config.yaml',
61126
'project_dir': '',
62127
'env_dir': '.datumaro',
128+
'detached': False,
63129
}, mutable=False, schema=PROJECT_SCHEMA)

datumaro/components/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -452,8 +452,9 @@ def __init__(self, source: IDataset = None,
452452

453453
self._format = DEFAULT_FORMAT
454454
self._source_path = None
455+
self._options = {}
455456

456-
def define_categories(self, categories: Dict):
457+
def define_categories(self, categories: CategoriesInfo):
457458
assert not self._data._categories and self._data._source is None
458459
self._data._categories = categories
459460

@@ -626,8 +627,7 @@ def import_from(cls, path: str, format: str = None, env: Environment = None,
626627
if format in env.importers:
627628
importer = env.make_importer(format)
628629
with logging_disabled(log.INFO):
629-
project = importer(path, **kwargs)
630-
detected_sources = list(project.config.sources.values())
630+
detected_sources = importer(path, **kwargs)
631631
elif format in env.extractors:
632632
detected_sources = [{
633633
'url': path, 'format': format, 'options': kwargs

datumaro/components/environment.py

Lines changed: 0 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,12 @@
44

55
from functools import partial
66
from glob import glob
7-
import git
87
import inspect
98
import logging as log
109
import os
1110
import os.path as osp
1211

1312
from datumaro.components.config import Config
14-
from datumaro.components.config_model import Model, Source
1513
from datumaro.util.os_util import import_foreign_module
1614

1715

@@ -46,29 +44,6 @@ def __getitem__(self, key):
4644
def __contains__(self, key):
4745
return key in self.items
4846

49-
50-
class ModelRegistry(Registry):
51-
def __init__(self, config=None):
52-
super().__init__(config, item_type=Model)
53-
54-
def load(self, config):
55-
# TODO: list default dir, insert values
56-
if 'models' in config:
57-
for name, model in config.models.items():
58-
self.register(name, model)
59-
60-
61-
class SourceRegistry(Registry):
62-
def __init__(self, config=None):
63-
super().__init__(config, item_type=Source)
64-
65-
def load(self, config):
66-
# TODO: list default dir, insert values
67-
if 'sources' in config:
68-
for name, source in config.sources.items():
69-
self.register(name, source)
70-
71-
7247
class PluginRegistry(Registry):
7348
def __init__(self, config=None, builtin=None, local=None):
7449
super().__init__(config)
@@ -85,47 +60,6 @@ def __init__(self, config=None, builtin=None, local=None):
8560
self.register(k, v)
8661

8762

88-
class GitWrapper:
89-
def __init__(self, config=None):
90-
self.repo = None
91-
92-
if config is not None and config.project_dir:
93-
self.init(config.project_dir)
94-
95-
@staticmethod
96-
def _git_dir(base_path):
97-
return osp.join(base_path, '.git')
98-
99-
@classmethod
100-
def spawn(cls, path):
101-
spawn = not osp.isdir(cls._git_dir(path))
102-
repo = git.Repo.init(path=path)
103-
if spawn:
104-
repo.config_writer().set_value("user", "name", "User") \
105-
.set_value("user", "email", "user@nowhere.com") \
106-
.release()
107-
# gitpython does not support init, use git directly
108-
repo.git.init()
109-
repo.git.commit('-m', 'Initial commit', '--allow-empty')
110-
return repo
111-
112-
def init(self, path):
113-
self.repo = self.spawn(path)
114-
return self.repo
115-
116-
def is_initialized(self):
117-
return self.repo is not None
118-
119-
def create_submodule(self, name, dst_dir, **kwargs):
120-
self.repo.create_submodule(name, dst_dir, **kwargs)
121-
122-
def has_submodule(self, name):
123-
return name in [submodule.name for submodule in self.repo.submodules]
124-
125-
def remove_submodule(self, name, **kwargs):
126-
return self.repo.submodule(name).remove(**kwargs)
127-
128-
12963
class Environment:
13064
_builtin_plugins = None
13165
PROJECT_EXTRACTOR_NAME = 'datumaro_project'
@@ -136,11 +70,6 @@ def __init__(self, config=None):
13670
config = Config(config,
13771
fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA)
13872

139-
self.models = ModelRegistry(config)
140-
self.sources = SourceRegistry(config)
141-
142-
self.git = GitWrapper(config)
143-
14473
env_dir = osp.join(config.project_dir, config.env_dir)
14574
builtin = self._load_builtin_plugins()
14675
custom = self._load_plugins2(osp.join(env_dir, config.plugins_dir))
@@ -284,12 +213,6 @@ def make_converter(self, name, *args, **kwargs):
284213
def make_transform(self, name, *args, **kwargs):
285214
return partial(self.transforms.get(name), *args, **kwargs)
286215

287-
def register_model(self, name, model):
288-
self.models.register(name, model)
289-
290-
def unregister_model(self, name):
291-
self.models.unregister(name)
292-
293216
def is_format_known(self, name):
294217
return name in self.importers or name in self.extractors
295218

datumaro/components/extractor.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -649,22 +649,18 @@ def find_sources(cls, path) -> List[Dict]:
649649
raise NotImplementedError()
650650

651651
def __call__(self, path, **extra_params):
652-
from datumaro.components.project import Project # cyclic import
653-
project = Project()
654-
655-
sources = self.find_sources(osp.normpath(path))
656-
if len(sources) == 0:
652+
found_sources = self.find_sources(osp.normpath(path))
653+
if len(found_sources) == 0:
657654
raise Exception("Failed to find dataset at '%s'" % path)
658655

659-
for desc in sources:
656+
sources = []
657+
for desc in found_sources:
660658
params = dict(extra_params)
661659
params.update(desc.get('options', {}))
662660
desc['options'] = params
661+
sources.append(desc)
663662

664-
source_name = osp.splitext(osp.basename(desc['url']))[0]
665-
project.add_source(source_name, desc)
666-
667-
return project
663+
return sources
668664

669665
@classmethod
670666
def _find_sources_recursive(cls, path, ext, extractor_name,

0 commit comments

Comments
 (0)