|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +from util import util |
| 4 | +import torch |
| 5 | +import data |
| 6 | +import models |
| 7 | + |
| 8 | +class BaseOptions(): |
| 9 | + def __init__(self): |
| 10 | + self.parser = argparse.ArgumentParser() |
| 11 | + self.initialized = False |
| 12 | + |
| 13 | + def initialize(self): |
| 14 | + # experiment specifics |
| 15 | + self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models') |
| 16 | + self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') |
| 17 | + self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') |
| 18 | + self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use') |
| 19 | + self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') |
| 20 | + self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator') |
| 21 | + self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit") |
| 22 | + self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose') |
| 23 | + self.parser.add_argument('--fp16', action='store_true', default=False, help='train with AMP') |
| 24 | + self.parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training') |
| 25 | + |
| 26 | + # input/output sizes |
| 27 | + self.parser.add_argument('--image_nc', type=int, default=3) |
| 28 | + self.parser.add_argument('--pose_nc', type=int, default=18) |
| 29 | + self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size') |
| 30 | + self.parser.add_argument('--old_size', type=int, default=(256, 176), help='Scale images to this size. The final image will be cropped to --crop_size.') |
| 31 | + self.parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size') |
| 32 | + self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') |
| 33 | + self.parser.add_argument('--label_nc', type=int, default=35, help='# of input label channels') |
| 34 | + self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') |
| 35 | + self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') |
| 36 | + |
| 37 | + # for setting inputs |
| 38 | + self.parser.add_argument('--dataset_mode', type=str, default='fashion') |
| 39 | + self.parser.add_argument('--dataroot', type=str, default='/media/data2/zhangpz/DataSet/Fashion') |
| 40 | + self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') |
| 41 | + self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') |
| 42 | + self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation') |
| 43 | + self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data') |
| 44 | + self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') |
| 45 | + |
| 46 | + # for displays |
| 47 | + self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size') |
| 48 | + self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed') |
| 49 | + self.parser.add_argument('--display_id', type=int, default=0, help='display id of the web') # 1 |
| 50 | + self.parser.add_argument('--display_port', type=int, default=8096, help='visidom port of the web display') |
| 51 | + self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, |
| 52 | + help='if positive, display all images in a single visidom web panel') |
| 53 | + self.parser.add_argument('--display_env', type=str, default=self.parser.parse_known_args()[0].name.replace('_', ''), |
| 54 | + help='the environment of visidom display') |
| 55 | + # for instance-wise features |
| 56 | + self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input') |
| 57 | + self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input') |
| 58 | + self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input') |
| 59 | + self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features') |
| 60 | + self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps') |
| 61 | + self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder') |
| 62 | + self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer') |
| 63 | + self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features') |
| 64 | + |
| 65 | + self.initialized = True |
| 66 | + |
| 67 | + def parse(self, save=True): |
| 68 | + if not self.initialized: |
| 69 | + self.initialize() |
| 70 | + opt, _ = self.parser.parse_known_args() |
| 71 | + # modify the options for different models |
| 72 | + model_option_set = models.get_option_setter(opt.model) |
| 73 | + self.parser = model_option_set(self.parser, self.isTrain) |
| 74 | + |
| 75 | + data_option_set = data.get_option_setter(opt.dataset_mode) |
| 76 | + self.parser = data_option_set(self.parser, self.isTrain) |
| 77 | + |
| 78 | + self.opt = self.parser.parse_args() |
| 79 | + self.opt.isTrain = self.isTrain # train or test |
| 80 | + |
| 81 | + if torch.cuda.is_available(): |
| 82 | + self.opt.device = torch.device("cuda") |
| 83 | + torch.backends.cudnn.benchmark = True # cudnn auto-tuner |
| 84 | + else: |
| 85 | + self.opt.device = torch.device("cpu") |
| 86 | + |
| 87 | + str_ids = self.opt.gpu_ids.split(',') |
| 88 | + self.opt.gpu_ids = [] |
| 89 | + for str_id in str_ids: |
| 90 | + id = int(str_id) |
| 91 | + if id >= 0: |
| 92 | + self.opt.gpu_ids.append(id) |
| 93 | + |
| 94 | + # set gpu ids |
| 95 | + if len(self.opt.gpu_ids) > 0: |
| 96 | + torch.cuda.set_device(self.opt.gpu_ids[0]) |
| 97 | + |
| 98 | + args = vars(self.opt) |
| 99 | + |
| 100 | + print('------------ Options -------------') |
| 101 | + for k, v in sorted(args.items()): |
| 102 | + print('%s: %s' % (str(k), str(v))) |
| 103 | + print('-------------- End ----------------') |
| 104 | + |
| 105 | + # save to the disk |
| 106 | + expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) |
| 107 | + util.mkdirs(expr_dir) |
| 108 | + if save and not (self.isTrain and self.opt.continue_train): |
| 109 | + name = 'train' if self.isTrain else 'test' |
| 110 | + file_name = os.path.join(expr_dir, name+'_opt.txt') |
| 111 | + with open(file_name, 'wt') as opt_file: |
| 112 | + opt_file.write('------------ Options -------------\n') |
| 113 | + for k, v in sorted(args.items()): |
| 114 | + opt_file.write('%s: %s\n' % (str(k), str(v))) |
| 115 | + opt_file.write('-------------- End ----------------\n') |
| 116 | + return self.opt |
0 commit comments