Skip to content
Prev Previous commit
Next Next commit
update configs
  • Loading branch information
gaoyang07 committed Jan 11, 2023
commit 2ae2175be0ff328b0e72cdc595421e0d64501137
49 changes: 49 additions & 0 deletions configs/_base_/settings/cifar10_bs96_nsga.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Dataset settings: CIFAR-10, batch size 96, used by the NSGA-NetV2 configs.
dataset_type = 'mmcls.CIFAR10'

data_preprocessor = dict(
    type='mmcls.ClsDataPreprocessor',
    num_classes=10,
    # RGB format normalization parameters (per-channel mean/std on the
    # 0-255 pixel scale, standard CIFAR-10 statistics).
    mean=[125.307, 122.961, 113.8575],
    std=[51.5865, 50.847, 51.255],
    # loaded images are already RGB format, so no BGR->RGB conversion
    to_rgb=False)

# Training augmentation: random crop with padding, horizontal flip, Cutout.
train_pipeline = [
    dict(type='mmcls.RandomCrop', crop_size=32, padding=4),
    dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='mmcls.Cutout', shape=16, pad_val=0, prob=1.0),
    dict(type='mmcls.PackClsInputs'),
]

# No test-time augmentation; only pack inputs for the model.
test_pipeline = [
    dict(type='mmcls.PackClsInputs'),
]

train_dataloader = dict(
    batch_size=96,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_prefix='data/cifar10',
        test_mode=False,
        pipeline=train_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=True),
    persistent_workers=True,
)

val_dataloader = dict(
    batch_size=96,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        # Same prefix as the train dataloader (fixed trailing-slash
        # inconsistency: was 'data/cifar10/').
        data_prefix='data/cifar10',
        test_mode=True,
        pipeline=test_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, ))

# Test phase reuses the validation dataloader and evaluator.
test_dataloader = val_dataloader
test_evaluator = val_evaluator
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# NSGANetV2 supernet config: searchable MobileNetV3 backbone, trained with
# the CIFAR-10 bs96 settings defined in this commit. NOTE: this file relies
# on mmengine Config semantics (`_base_` inheritance and attribute access);
# it is not executable as plain Python.
_base_ = [
    'mmcls::_base_/default_runtime.py',
    'mmcls::_base_/schedules/imagenet_bs2048.py',
    'mmrazor::_base_/settings/cifar10_bs96_nsga.py',
    'mmrazor::_base_/nas_backbones/nsga_mobilenetv3_supernet.py',
]

# Supernet architecture passed to the NSGANetV2 algorithm wrapper below.
supernet = dict(
    _scope_='mmrazor',
    type='SearchableImageClassifier',
    # `nas_backbone` is resolved by mmengine Config from the base backbone
    # file listed in `_base_` above.
    backbone=_base_.nas_backbone,
    neck=dict(type='SqueezeMeanPoolingWithDropout', drop_ratio=0.2),
    head=dict(
        type='DynamicLinearClsHead',
        # NOTE(review): num_classes=1000 (and the ImageNet schedule above)
        # while the settings base is a 10-class CIFAR-10 config — presumably
        # carried over from an ImageNet config; confirm this is intended.
        num_classes=1000,
        in_channels=1280,
        loss=dict(
            type='mmcls.LabelSmoothLoss',
            num_classes=1000,
            label_smooth_val=0.1,
            mode='original',
            loss_weight=1.0),
        topk=(1, 5)),
    # `input_resizer_cfg` comes from the base backbone/settings files.
    input_resizer_cfg=_base_.input_resizer_cfg,
    # Wire the head's input width to the backbone's last mutable channel.
    connect_head=dict(connect_with_backbone='backbone.last_mutable_channels'),
)

model = dict(
    _scope_='mmrazor',
    type='NSGANetV2',
    architecture=supernet,
    data_preprocessor=_base_.data_preprocessor,
    # Two mutators: channels (OneShotChannelMutator) and scalar values
    # (DynamicValueMutator); both operate on the predefined search space.
    mutators=dict(
        channel_mutator=dict(
            type='mmrazor.OneShotChannelMutator',
            channel_unit_cfg={
                'type': 'OneShotMutableChannelUnit',
                'default_args': {
                    'unit_predefined': True
                }
            },
            parse_cfg={'type': 'Predefined'}),
        value_mutator=dict(type='DynamicValueMutator')))

# Subnet sampling leaves some parameters unused in a given step.
find_unused_parameters = True

# Keep only the latest checkpoint plus the best one by the default metric.
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook', interval=1, max_keep_ckpts=1, save_best='auto'))
93 changes: 58 additions & 35 deletions mmrazor/models/algorithms/nas/nsganetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,50 @@
from mmrazor.registry import MODELS
from mmrazor.utils import ValidFixMutable
from ..base import BaseAlgorithm, LossResults
from ..space_mixin import SpaceMixin

VALID_MUTATOR_TYPE = Union[BaseMutator, Dict]
VALID_MUTATORS_TYPE = Dict[str, Union[BaseMutator, Dict]]
VALID_DISTILLER_TYPE = Union[ConfigurableDistiller, Dict]


@MODELS.register_module()
class NSGANetV2(BaseAlgorithm):
"""NSGANetV2 algorithm."""
class NSGANetV2(BaseAlgorithm, SpaceMixin):
"""Implementation of `NSGANetV2 <https://arxiv.org/abs/2007.10396>`_

NSGANetV2 generates task-specific models that are competitive under
multiple competing objectives.

NSGANetV2 comprises of two surrogates, one at the architecture level to
improve sample efficiency and one at the weights level, through a supernet,
to improve gradient descent training efficiency.

The logic of the search part is implemented in
:class:`mmrazor.engine.NSGA2SearchLoop`

Args:
architecture (dict|:obj:`BaseModel`): The config of :class:`BaseModel`
or built model. Corresponding to supernet in NAS algorithm.
mutators (VALID_MUTATORS_TYPE): Configs to build different mutators.
fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or
[Reviewer comment (Contributor): "fix_subnet is not needed."]
loaded dict or built :obj:`FixSubnet`. Defaults to None.
data_preprocessor (Optional[Union[dict, nn.Module]]): The pre-process
config of :class:`BaseDataPreprocessor`. Defaults to None.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.2.
backbone_dropout_stages (List): Stages to be set dropout. Defaults to
[6, 7].
norm_training (bool): Whether to set norm layers to training mode,
namely, not freeze running stats (mean and var). Note: Effect on
Batch Norm and its variants only. Defaults to False.
init_cfg (Optional[dict]): Init config for ``BaseModule``.
Defaults to None.

Note:
NSGANetV2 uses two mutators which are ``DynamicValueMutator`` and
``ChannelMutator``. `DynamicValueMutator` handle the mutable object
``OneShotMutableValue`` while ChannelMutator handle the mutable object
``OneShotMutableChannel``.
"""

def __init__(self,
architecture: Union[BaseModel, Dict],
Expand All @@ -33,47 +68,35 @@ def __init__(self,
init_cfg: Optional[dict] = None):
super().__init__(architecture, data_preprocessor, init_cfg)

if isinstance(mutators, dict):
built_mutators: Dict = dict()
for name, mutator_cfg in mutators.items():
if 'parse_cfg' in mutator_cfg and isinstance(
mutator_cfg['parse_cfg'], dict):
assert mutator_cfg['parse_cfg'][
'type'] == 'Predefined', \
'BigNAS only support predefined.'
mutator: BaseMutator = MODELS.build(mutator_cfg)
built_mutators[name] = mutator
mutator.prepare_from_supernet(self.architecture)
self.mutators = built_mutators
else:
raise TypeError('mutator should be a `dict` but got '
f'{type(mutators)}')

self.drop_path_rate = drop_path_rate
self.backbone_dropout_stages = backbone_dropout_stages
self.norm_training = norm_training
self.is_supernet = True

if fix_subnet:
# Avoid circular import
from mmrazor.structures import load_fix_subnet

# According to fix_subnet, delete the unchosen part of supernet
load_fix_subnet(self, fix_subnet)
self.is_supernet = False
else:
if isinstance(mutators, dict):
built_mutators: Dict = dict()
for name, mutator_cfg in mutators.items():
if 'parse_cfg' in mutator_cfg and isinstance(
mutator_cfg['parse_cfg'], dict):
assert mutator_cfg['parse_cfg'][
'type'] == 'Predefined', \
'NSGANetV2 only support predefined.'
mutator: BaseMutator = MODELS.build(mutator_cfg)
built_mutators[name] = mutator
mutator.prepare_from_supernet(self.architecture)
self.mutators = built_mutators
else:
raise TypeError('mutator should be a `dict` but got '
f'{type(mutators)}')
self._build_search_space()
self.is_supernet = True

def sample_subnet(self, kind='random') -> Dict:
"""Random sample subnet by mutator."""
subnet = dict()
for mutator in self.mutators.values():
subnet.update(mutator.sample_choices(kind))
return subnet

def set_subnet(self, subnet: Dict[str, Dict[int, Union[int,
list]]]) -> None:
"""Set the subnet sampled by :meth:sample_subnet."""
for mutator in self.mutators.values():
mutator.set_choices(subnet)
self.drop_path_rate = drop_path_rate
self.backbone_dropout_stages = backbone_dropout_stages
self.norm_training = norm_training

def loss(
self,
Expand Down