forked from jasonppy/VoiceStar
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
82 lines (71 loc) · 3.11 KB
/
main.py
File metadata and controls
82 lines (71 loc) · 3.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from pathlib import Path
import torch, os
from tqdm import tqdm
import pickle
import argparse
import logging, datetime
import torch.distributed as dist
from config import MyParser
from steps import trainer
from copy_codebase import copy_codebase
def world_info_from_env():
    """Read the distributed-launch ranks from the process environment.

    Returns:
        A ``(local_rank, global_rank, world_size)`` tuple of ints parsed from
        the ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` environment variables.

    Raises:
        KeyError: if any of the three variables is not set.
        ValueError: if a variable's value cannot be parsed as an int.
    """
    # Variables are read in this fixed order so the first missing one is the
    # one reported.
    variable_names = ("LOCAL_RANK", "RANK", "WORLD_SIZE")
    local_rank, global_rank, world_size = (
        int(os.environ[name]) for name in variable_names
    )
    return local_rank, global_rank, world_size
if __name__ == "__main__":
    # Script entry point: parse the CLI arguments, reconcile them with a
    # previous run when resuming, initialise torch.distributed, snapshot the
    # codebase into the experiment directory (rank 0 only), then train.
    import shutil  # local import: only this entry point needs it

    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    torch.cuda.empty_cache()

    args = MyParser().parse_args()
    # Validate up front with a real exception (asserts vanish under -O); an
    # empty exp_dir would otherwise resolve the paths below against "/".
    if not args.exp_dir:
        raise ValueError("exp_dir must be a non-empty path")
    exp_dir = Path(args.exp_dir)
    exp_dir.mkdir(exist_ok=True, parents=True)
    logging.info(f"exp_dir: {str(exp_dir)}")

    bundle_path = exp_dir / "bundle.pth"
    bundle_prev_path = exp_dir / "bundle_prev.pth"
    args_path = exp_dir / "args.pkl"

    # Resume only when a checkpoint (current or previous) actually exists;
    # otherwise this is treated as a fresh run even if --resume was passed.
    if args.resume and (os.path.exists(bundle_path) or os.path.exists(bundle_prev_path)):
        if not os.path.exists(bundle_path):
            # Fall back to the previous bundle. shutil.copy raises on failure
            # and is safe for paths containing spaces or shell metacharacters.
            # NOTE(review): this runs in every process before the process
            # group exists, so several ranks may perform the copy concurrently.
            shutil.copy(bundle_prev_path, bundle_path)
        # NOTE(security): pickle.load executes arbitrary code from the file.
        # args.pkl is written by this script on a fresh run; never point
        # exp_dir at a directory whose contents are untrusted.
        with open(args_path, "rb") as f:
            old_args = pickle.load(f)
        # Start from the stored arguments and let every argument of the
        # current invocation (including `resume`) override them; keys that
        # exist only in the stored arguments are kept.
        args = argparse.Namespace(**{**vars(old_args), **vars(args)})
    else:
        args.resume = False
        with open(args_path, "wb") as f:
            pickle.dump(args, f)

    # Use a longer process-group timeout (for generation): 7200 s = 2 hours.
    timeout = datetime.timedelta(seconds=7200)

    if args.multinodes:
        # Multi-node launch: the local rank comes from the environment.
        env_local_rank, _, _ = world_info_from_env()
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=timeout)
    else:
        env_local_rank = None
        dist.init_process_group(backend='nccl', init_method='env://', timeout=timeout)

    if args.local_wandb:
        os.environ["WANDB_MODE"] = "offline"

    rank = dist.get_rank()
    if rank == 0:
        logging.info(args)
        logging.info(f"exp_dir: {str(exp_dir)}")
    world_size = dist.get_world_size()
    # Without --multinodes the global rank doubles as the CUDA device index
    # (presumably all processes share one machine — confirm with the launcher).
    local_rank = rank if env_local_rank is None else int(env_local_rank)

    num_devices = torch.cuda.device_count()
    logging.info(f"{local_rank=}, {rank=}, {world_size=}, {type(local_rank)=}, {type(rank)=}, {type(world_size)=}")
    for device_idx in range(num_devices):
        device_name = torch.cuda.get_device_name(device_idx)
        logging.info(f"Device {device_idx}: {device_name}")
    torch.cuda.set_device(local_rank)

    if rank == 0:
        # Snapshot the source tree next to the experiment outputs, in a
        # timestamped folder, once per job.
        # NOTE(review): the codebase location is hard-coded to ~/VoiceStar;
        # confirm this matches where the repository is checked out.
        user_dir = os.path.expanduser("~")
        codebase_name = "VoiceStar"
        codebase_dir = os.path.join(user_dir, codebase_name)
        now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        copy_codebase(codebase_dir, os.path.join(exp_dir, f"{codebase_name}_{now}"), max_size_mb=5, gitignore_path=os.path.join(codebase_dir, ".gitignore"))

    # Every rank builds its own trainer and joins training.
    my_trainer = trainer.Trainer(args, world_size, rank, local_rank)
    my_trainer.train()