|
| 1 | +from __future__ import absolute_import |
| 2 | +from __future__ import division |
| 3 | +from __future__ import print_function |
| 4 | + |
| 5 | +import os.path as op |
| 6 | +import yaml |
| 7 | +from yacs.config import CfgNode as CN |
| 8 | + |
| 9 | +from comm import comm |
| 10 | + |
| 11 | + |
# Default configuration tree. Config files and command-line overrides are
# merged on top of these values by update_config(). yacs rejects keys that are
# not declared here unless the enclosing node was created with
# new_allowed=True.
_C = CN()

# Config files merged before the including file itself; entries are resolved
# relative to the including file's directory (see _update_config_from_file).
# Empty strings are skipped.
_C.BASE = ['']
# update_config() prefixes this with the config file's base name (no extension).
_C.NAME = ''
_C.DATA_DIR = ''
_C.DIST_BACKEND = 'nccl'
_C.GPUS = (0,)
# _C.LOG_DIR = ''
_C.MULTIPROCESSING_DISTRIBUTED = True
_C.OUTPUT_DIR = ''
_C.PIN_MEMORY = True
_C.PRINT_FREQ = 20
# Overwritten with comm.rank by update_config().
_C.RANK = 0
_C.VERBOSE = True
_C.WORKERS = 4
_C.MODEL_SUMMARY = False

# Mixed-precision settings (AMP presumably stands for automatic mixed
# precision -- confirm against the training code).
_C.AMP = CN()
_C.AMP.ENABLED = False
_C.AMP.MEMORY_FORMAT = 'nchw'

# Cudnn related params
_C.CUDNN = CN()
_C.CUDNN.BENCHMARK = True
_C.CUDNN.DETERMINISTIC = False
_C.CUDNN.ENABLED = True

# common params for NETWORK
_C.MODEL = CN()
_C.MODEL.NAME = 'cls_hrnet'
_C.MODEL.INIT_WEIGHTS = True
_C.MODEL.PRETRAINED = ''
_C.MODEL.PRETRAINED_LAYERS = ['*']
_C.MODEL.NUM_CLASSES = 1000
# Open node: config files may add arbitrary model-specific keys here.
_C.MODEL.SPEC = CN(new_allowed=True)





# Loss settings; open node, so config files may add further keys.
_C.LOSS = CN(new_allowed=True)
_C.LOSS.LABEL_SMOOTHING = 0.0
_C.LOSS.LOSS = 'softmax'

# DATASET related params
_C.DATASET = CN()
_C.DATASET.ROOT = ''
_C.DATASET.DATASET = 'imagenet'
_C.DATASET.TRAIN_SET = 'train'
_C.DATASET.TEST_SET = 'val'
_C.DATASET.DATA_FORMAT = 'jpg'
_C.DATASET.LABELMAP = ''
_C.DATASET.TRAIN_TSV_LIST = []
_C.DATASET.TEST_TSV_LIST = []
_C.DATASET.SAMPLER = 'default'

_C.DATASET.TARGET_SIZE = -1

# Input statistics (presumably per-channel normalization mean/std -- confirm
# against the data pipeline).
_C.INPUT = CN()
_C.INPUT.MEAN = [0.485, 0.456, 0.406]
_C.INPUT.STD = [0.229, 0.224, 0.225]

# data augmentation
_C.AUG = CN()
_C.AUG.SCALE = (0.08, 1.0)
_C.AUG.RATIO = (3.0/4.0, 4.0/3.0)
_C.AUG.COLOR_JITTER = [0.4, 0.4, 0.4, 0.1, 0.0]
_C.AUG.GRAY_SCALE = 0.0
_C.AUG.GAUSSIAN_BLUR = 0.0
_C.AUG.DROPBLOCK_LAYERS = [3, 4]
_C.AUG.DROPBLOCK_KEEP_PROB = 1.0
_C.AUG.DROPBLOCK_BLOCK_SIZE = 7
# update_config() forces MIXUP_PROB to 1.0 whenever MIXUP > 0, MIXCUT > 0 or
# MIXCUT_MINMAX is non-empty.
_C.AUG.MIXUP_PROB = 0.0
_C.AUG.MIXUP = 0.0
_C.AUG.MIXCUT = 0.0
_C.AUG.MIXCUT_MINMAX = []
_C.AUG.MIXUP_SWITCH_PROB = 0.5
_C.AUG.MIXUP_MODE = 'batch'
_C.AUG.MIXCUT_AND_MIXUP = False
_C.AUG.INTERPOLATION = 2
# Open node: config files may add further TIMM_AUG keys.
_C.AUG.TIMM_AUG = CN(new_allowed=True)
_C.AUG.TIMM_AUG.USE_LOADER = False
_C.AUG.TIMM_AUG.USE_TRANSFORM = False

# train
_C.TRAIN = CN()

_C.TRAIN.AUTO_RESUME = True
_C.TRAIN.CHECKPOINT = ''
# Open node. update_config() reads METHOD from it and, when METHOD is 'timm',
# writes ARGS.epochs; neither key is declared here, so they come from the
# config file.
_C.TRAIN.LR_SCHEDULER = CN(new_allowed=True)
# When SCALE_LR is True, update_config() multiplies LR by comm.world_size.
_C.TRAIN.SCALE_LR = True
_C.TRAIN.LR = 0.001

_C.TRAIN.OPTIMIZER = 'sgd'
# Open node. When OPTIMIZER is 'timm', update_config() sets OPTIMIZER_ARGS.lr
# to TRAIN.LR.
_C.TRAIN.OPTIMIZER_ARGS = CN(new_allowed=True)
_C.TRAIN.MOMENTUM = 0.9
_C.TRAIN.WD = 0.0001
_C.TRAIN.WITHOUT_WD_LIST = []
_C.TRAIN.NESTEROV = True
# for adam
_C.TRAIN.GAMMA1 = 0.99
_C.TRAIN.GAMMA2 = 0.0

_C.TRAIN.BEGIN_EPOCH = 0
# Copied into LR_SCHEDULER.ARGS.epochs by update_config() for the 'timm' scheduler.
_C.TRAIN.END_EPOCH = 100

_C.TRAIN.IMAGE_SIZE = [224, 224]  # width * height, ex: 192 * 256
_C.TRAIN.BATCH_SIZE_PER_GPU = 32
_C.TRAIN.SHUFFLE = True

_C.TRAIN.EVAL_BEGIN_EPOCH = 0

_C.TRAIN.DETECT_ANOMALY = False

_C.TRAIN.CLIP_GRAD_NORM = 0.0
_C.TRAIN.SAVE_ALL_MODELS = False

# testing
_C.TEST = CN()

# batch size for each device
_C.TEST.BATCH_SIZE_PER_GPU = 32
_C.TEST.CENTER_CROP = True
_C.TEST.IMAGE_SIZE = [224, 224]  # width * height, ex: 192 * 256
_C.TEST.INTERPOLATION = 2
_C.TEST.MODEL_FILE = ''
_C.TEST.REAL_LABELS = False
_C.TEST.VALID_LABELS = ''

# fine-tuning
_C.FINETUNE = CN()
_C.FINETUNE.FINETUNE = False
_C.FINETUNE.USE_TRAIN_AUG = False
_C.FINETUNE.BASE_LR = 0.003
_C.FINETUNE.BATCH_SIZE = 512
_C.FINETUNE.EVAL_EVERY = 3000
_C.FINETUNE.TRAIN_MODE = True
# _C.FINETUNE.MODEL_FILE = ''
_C.FINETUNE.FROZEN_LAYERS = []
# Open node: config files may add further scheduler keys besides DECAY_TYPE.
_C.FINETUNE.LR_SCHEDULER = CN(new_allowed=True)
_C.FINETUNE.LR_SCHEDULER.DECAY_TYPE = 'step'

# debug
_C.DEBUG = CN()
_C.DEBUG.DEBUG = False
| 157 | + |
| 158 | + |
def _update_config_from_file(config, cfg_file):
    """Merge ``cfg_file``, preceded by every file in its ``BASE`` list, into ``config``.

    Base entries are resolved relative to the directory containing
    ``cfg_file`` and are merged (recursively) before ``cfg_file`` itself, so
    the including file is always merged last. Empty or missing ``BASE``
    entries are skipped. ``config`` is left frozen on return.

    Args:
        config: yacs ``CfgNode`` that is updated in place.
        cfg_file: Path of the YAML config file to merge.
    """
    with open(cfg_file, 'r') as f:
        # NOTE(review): FullLoader can build more than plain scalars/containers;
        # only point this at trusted config files.
        # An empty YAML document parses to None, so fall back to an empty dict.
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) or {}

    # `or []` also covers an explicit `BASE: null` in the file.
    for base_cfg in yaml_cfg.get('BASE') or []:
        if base_cfg:
            _update_config_from_file(
                config, op.join(op.dirname(cfg_file), base_cfg)
            )

    print('=> merge config from {}'.format(cfg_file))
    # Each nested call above returns with the config frozen, so unfreeze right
    # before merging this file rather than once at the top.
    config.defrost()
    config.merge_from_file(cfg_file)
    config.freeze()
| 172 | + |
| 173 | + |
def update_config(config, args):
    """Populate ``config`` from a config file plus command-line overrides.

    Steps, in order:
      1. merge ``args.cfg`` (and its BASE files) via _update_config_from_file;
      2. merge the ``args.opts`` override list with ``merge_from_list``;
      3. multiply TRAIN.LR by ``comm.world_size`` when TRAIN.SCALE_LR is set;
      4. prefix NAME with the config file's base name (extension stripped);
      5. store ``comm.rank`` in RANK;
      6. for the 'timm' scheduler / optimizer, copy TRAIN.END_EPOCH and
         TRAIN.LR into their argument nodes;
      7. force AUG.MIXUP_PROB to 1.0 when any mixup/mixcut option is active.

    ``config`` is modified in place and frozen on return.

    Args:
        config: yacs ``CfgNode`` to update.
        args: Object exposing ``cfg`` (path of the YAML config file) and
            ``opts`` (override list handed to ``merge_from_list``; ``None``
            is treated as no overrides).
    """
    _update_config_from_file(config, args.cfg)

    config.defrost()
    config.merge_from_list(args.opts or [])

    train = config.TRAIN
    if train.SCALE_LR:
        train.LR *= comm.world_size

    file_name, _ = op.splitext(op.basename(args.cfg))
    config.NAME = file_name + config.NAME
    config.RANK = comm.rank

    # TRAIN.LR_SCHEDULER declares no default METHOD, so attribute access would
    # raise AttributeError for configs that omit it; .get() skips the timm
    # hook in that case instead.
    if train.LR_SCHEDULER.get('METHOD') == 'timm':
        train.LR_SCHEDULER.ARGS.epochs = train.END_EPOCH

    if train.OPTIMIZER == 'timm':
        train.OPTIMIZER_ARGS.lr = train.LR

    aug = config.AUG
    if aug.MIXUP > 0.0 or aug.MIXCUT > 0.0 or aug.MIXCUT_MINMAX:
        aug.MIXUP_PROB = 1.0
    config.freeze()
| 195 | + |
| 196 | + |
def save_config(cfg, path):
    """Write ``cfg.dump()`` to ``path``, but only on the main process.

    Processes for which ``comm.is_main_process()`` is false return without
    touching the filesystem.

    Args:
        cfg: Config object exposing ``dump()``.
        path: Destination file path; it is opened in write mode.
    """
    if not comm.is_main_process():
        return

    with open(path, 'w') as out:
        out.write(cfg.dump())
| 201 | + |
| 202 | + |
if __name__ == '__main__':
    # Script mode: write the default configuration's string form to the file
    # named by the first command-line argument.
    import sys

    destination = sys.argv[1]
    with open(destination, 'w') as out:
        out.write(str(_C) + '\n')
| 207 | + |
0 commit comments