-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
318 lines (284 loc) · 10.6 KB
/
main.py
File metadata and controls
318 lines (284 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
# Standard library
import argparse
import glob
import json
import os
import random
import shutil
import time

# Third-party
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import wandb

# Local
from config import get_config
from data_loader.data_loaders import get_loader
from logger.logger import create_logger
from model import get_model
import model.metric as module_metric
from trainer.inferencer import Inferencer
from trainer.tester import Tester
from trainer.trainer import Trainer
from utils.lr_scheduler import get_scheduler
from utils.optimizer import get_optimizer
from utils.utils import init_wandb_run, prepare_device
def parse_option():
    """Parse command-line arguments and build the run configuration.

    Returns:
        tuple: ``(args, config)`` where ``args`` is the parsed ``argparse``
        namespace and ``config`` is the configuration object produced by
        ``get_config(args)``.
    """
    parser = argparse.ArgumentParser(
        "VM-ASR training, evaluation, and inference script", add_help=False
    )
    parser.add_argument(
        "--cfg",
        type=str,
        required=True,
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs="+",
    )
    # easy config modification
    parser.add_argument("--batch-size", type=int, help="batch size for single GPU")
    parser.add_argument(
        "--input_sr",
        type=int,
        help="the input sample rate (if set, the random resample will be disabled)",
    )
    parser.add_argument(
        "--target_sr",
        type=int,
        help="the target sample rate",
    )
    parser.add_argument("--resume", type=str, help="path to checkpoint for models")
    parser.add_argument(
        "--accumulation-steps", type=int, help="gradient accumulation steps"
    )
    parser.add_argument(
        "--disable_amp", action="store_true", help="Disable pytorch amp"
    )
    parser.add_argument(
        "--output",
        default="logs",
        type=str,
        metavar="PATH",
        # Fixed: the help text used to claim "(default: output)" while the
        # actual default is "logs".
        help="root of output folder, the full path is <output>/<model_name>/<tag> (default: logs)",
    )
    parser.add_argument(
        "--tag",
        # NOTE: the timestamp default is evaluated once, when this function
        # (re)defines the parser — not per run of the resulting script.
        default=time.strftime("%Y%m%d%H%M%S", time.localtime()),
        help="tag of experiment",
    )
    parser.add_argument("--eval", action="store_true", help="Perform evaluation only")
    parser.add_argument(
        "--inference", action="store_true", help="Perform inference only"
    )
    parser.add_argument(
        "--input", type=str, help="Input file or directory for inference"
    )
    # TODO: Add throughput mode
    parser.add_argument(
        "--throughput", action="store_true", help="Test throughput only"
    )
    # parse_known_args: unrecognized options are collected instead of erroring.
    args, unparsed = parser.parse_known_args()
    config = get_config(args)
    return args, config
def main(config):
    """Run one of three modes — inference, evaluation, or training — based on
    the ``INFERENCE_MODE`` / ``EVAL_MODE`` flags in ``config``.

    NOTE(review): this function reads the module-level globals ``args``
    (parsed CLI namespace, used for ``args.input``) and ``logger`` (created in
    ``__main__``), so it is only safe to call after both exist.
    """
    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    # Setup device with single/multiple GPUs
    device, device_ids = prepare_device(config.N_GPU)
    # Get the model
    models = get_model(config)
    # Set the models to device
    models = {k: v.to(device) for k, v in models.items()}
    # Get the metrics
    metrics = [getattr(module_metric, met) for met in config.TRAIN.METRICS]
    # W&B is only initialized for training runs (not eval/inference).
    if config.WANDB.ENABLE and not config.EVAL_MODE and not config.INFERENCE_MODE:
        init_wandb_run(config)
    # Inference mode
    if config.INFERENCE_MODE:
        logger.info(f"Starting inference...")
        logger.info(f"Loading checkpoint from {config.MODEL.RESUME_PATH}")
        # Make sure we only have the generator model for inference
        models = {"generator": models["generator"]}
        # Create the inferencer
        inferencer = Inferencer(
            models=models,
            metric_ftns=metrics,
            config=config,
            device=(device, device_ids),
            logger=logger,
        )
        # Check if input is specified
        if args.input is None:
            logger.error("Input path must be specified for inference mode")
            return
        # Run inference on the specified input: single file or whole directory.
        if os.path.isfile(args.input):
            inferencer.infer_file(args.input)
        elif os.path.isdir(args.input):
            inferencer.infer_directory(args.input)
        else:
            # NOTE(review): a nonexistent path only logs an error; "Inference
            # completed successfully" is still logged below.
            logger.error(f"Input path does not exist: {args.input}")
        logger.info("Inference completed successfully")
        return
    # Evaluation mode
    elif config.EVAL_MODE:
        logger.info(f"Starting evaluation ...")
        logger.info(f"Loading checkpoint from {config.MODEL.RESUME_PATH}")
        # Remove models except the generator
        models = {"generator": models["generator"]}
        # In eval mode get_loader returns a single test loader.
        data_loader_test = get_loader(config, logger)
        logger.info(f"TESTING: ({len(data_loader_test)} files)")
        # Test the trained model
        tester = Tester(
            models=models,
            metric_ftns=metrics,
            config=config,
            device=(device, device_ids),
            data_loader=data_loader_test,
            logger=logger,
        )
        tester.evaluate()
    # Training mode
    else:
        # In training mode get_loader returns (train, val) loaders.
        data_loader_train, data_loader_val = get_loader(config, logger)
        # Initialize the optimizer and learning rate scheduler
        optimizers = {"generator": None, "discriminator": None}
        lr_schedulers = {"generator": None, "discriminator": None}
        # Get the optimizer and lr_scheduler for the generator
        optimizers["generator"] = get_optimizer(config, models["generator"], logger)
        # Scheduler steps per epoch shrink when gradient accumulation is on.
        lr_schedulers["generator"] = (
            get_scheduler(
                config,
                optimizers["generator"],
                len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS,
            )
            if config.TRAIN.ACCUMULATION_STEPS > 1
            else get_scheduler(config, optimizers["generator"], len(data_loader_train))
        )
        # Get the optimizer and lr_scheduler for the discriminator
        if config.TRAIN.ADVERSARIAL.ENABLE:
            if config.TRAIN.ADVERSARIAL.DISCRIMINATORS is not None:
                # There would be more than one discriminators if specified
                # models["discriminator"] is a dict with keys as discriminator names, we need to iterate over them
                optimizers["discriminator"] = get_optimizer(
                    config,
                    [
                        models[disc_name]
                        for disc_name in config.TRAIN.ADVERSARIAL.DISCRIMINATORS
                    ],
                    logger,
                )
                lr_schedulers["discriminator"] = (
                    get_scheduler(
                        config,
                        optimizers["discriminator"],
                        len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS,
                    )
                    if config.TRAIN.ACCUMULATION_STEPS > 1
                    else get_scheduler(
                        config, optimizers["discriminator"], len(data_loader_train)
                    )
                )
            else:
                # Log an error if the discriminator is not specified
                # NOTE(review): only logs — training still proceeds below with
                # optimizers["discriminator"] left as None. Confirm Trainer
                # tolerates that when gan=True.
                logger.error(
                    "Adversarial training is enabled but the discriminator is not specified in the config file. Please specify the discriminator in the config file."
                )
        trainer = Trainer(
            models=models,
            metric_ftns=metrics,
            optimizers=optimizers,
            config=config,
            device=(device, device_ids),
            data_loader_train=data_loader_train,
            data_loader_val=data_loader_val,
            lr_schedulers=lr_schedulers,
            amp=config.AMP_ENABLE,
            gan=config.TRAIN.ADVERSARIAL.ENABLE,
            logger=logger,
        )
        trainer.train()
        # Close the W&B run if one was opened above.
        if config.WANDB.ENABLE:
            wandb.finish()
def validate_resume_path(config):
    """Validate ``config.MODEL.RESUME_PATH`` before resuming/eval/inference.

    Raises:
        FileNotFoundError: if the resume folder does not exist, or — in eval
            or inference mode — if it contains no ``*.pth`` checkpoint.
    """
    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip this validation.
    if not os.path.exists(config.MODEL.RESUME_PATH):
        raise FileNotFoundError(
            f"Folder not found, please check the path: {config.MODEL.RESUME_PATH}"
        )
    if config.EVAL_MODE or config.INFERENCE_MODE:
        # There must be a checkpoint for evaluation or inference
        if not glob.glob(os.path.join(config.MODEL.RESUME_PATH, "*.pth")):
            raise FileNotFoundError(
                f"No checkpoint found in the folder. Please check the path: {config.MODEL.RESUME_PATH}"
            )
def setup_test(config):
    """Prepare the output directory, config, and logger for evaluation mode.

    The TAG must be ``{input_sr}_{target_sr}``; results go to
    ``<TEST.RESULTS_DIR>/<run_name>/<target_sr>/<input_sr>`` (any previous
    results there are deleted first).

    Returns:
        tuple: ``(config, logger)`` with ``config.OUTPUT`` pointing at the
        freshly created output directory.
    """
    # Evaluate the trained model with the test dataset
    assert (
        len(config.TAG.split("_")) == 2
    ), "TAG should be in format {input_sr}_{target_sr}"
    input_sr, target_sr = config.TAG.split("_")
    # Example: "./results/16k_DeciData_MPD_WGAN_Local/16000/2000"
    output_dir = os.path.join(
        config.TEST.RESULTS_DIR,
        os.path.basename(config.MODEL.RESUME_PATH),
        target_sr,
        input_sr,
    )
    # Remove the existing output directory. shutil.rmtree replaces the old
    # `os.system(f"rm -rf {output_dir}")`: no shell involved, portable to
    # Windows, and safe for paths containing spaces or shell metacharacters.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    # Update config (it is frozen, so defrost/freeze around the mutation)
    config.defrost()
    config.OUTPUT = output_dir
    config.freeze()
    logger = create_logger(output_dir=config.OUTPUT, name=f"{config.MODEL.NAME}")
    return config, logger
def setup_inference(config):
    """Prepare the output directory, config, and logger for inference mode.

    Returns:
        tuple: ``(config, logger)`` with ``config.OUTPUT`` set to the
        inference results directory.
    """
    # TAG encodes the sample rates as "{input_sr}_{target_sr}".
    sr_parts = config.TAG.split("_")
    assert len(sr_parts) == 2, "TAG should be in format {input_sr}_{target_sr}"
    input_sr, target_sr = sr_parts
    # Layout: <INFERENCE.RESULTS_DIR>/<run_name>/<target_sr>/<input_sr>
    run_name = os.path.basename(config.MODEL.RESUME_PATH)
    output_dir = os.path.join(
        config.INFERENCE.RESULTS_DIR, run_name, target_sr, input_sr
    )
    os.makedirs(output_dir, exist_ok=True)
    # Point the (frozen) config's OUTPUT at the inference directory.
    config.defrost()
    config.OUTPUT = output_dir
    config.freeze()
    logger = create_logger(output_dir=config.OUTPUT, name=f"{config.MODEL.NAME}")
    return config, logger
if __name__ == "__main__":
args, config = parse_option()
# Create output folder
os.makedirs(config.OUTPUT, exist_ok=True)
os.makedirs(config.DEBUG_OUTPUT, exist_ok=True)
# Set the random seed for reproducibility
seed = config.SEED
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = True
if config.MODEL.RESUME_PATH is not None:
validate_resume_path(config)
if config.INFERENCE_MODE:
config, logger = setup_inference(config)
elif config.EVAL_MODE:
config, logger = setup_test(config)
else:
logger = create_logger(
output_dir=config.MODEL.RESUME_PATH,
name=f"{config.MODEL.NAME}",
load_existing=True,
)
logger.info(f"Resume training from {config.MODEL.RESUME_PATH}")
else:
logger = create_logger(output_dir=config.OUTPUT, name=f"{config.MODEL.NAME}")
logger.info(config.dump())
logger.info(json.dumps(vars(args)))
main(config)