# example_train.py
from argparse import ArgumentParser, Namespace
import torch
import sys
import litegs
import litegs.config
if __name__ == "__main__":
    # Script entry point: build the command-line interface, parse it, split
    # the parsed values back into the four litegs parameter groups, and hand
    # everything to litegs.training.start().
    parser = ArgumentParser(description="Training script parameters")

    # get_default_arg() yields one defaults object per parameter group, in the
    # order model / optimization / pipeline / densify. The 4-way unpack is
    # kept so an unexpected number of groups fails loudly here.
    lp_defaults, op_defaults, pp_defaults, dp_defaults = (
        litegs.config.get_default_arg()
    )
    param_groups = (
        (litegs.arguments.ModelParams, lp_defaults),
        (litegs.arguments.OptimizationParams, op_defaults),
        (litegs.arguments.PipelineParams, pp_defaults),
        (litegs.arguments.DensifyParams, dp_defaults),
    )

    # Each group registers its own options on the shared parser.
    for group, group_defaults in param_groups:
        group.add_cmdline_arg(group_defaults, parser)

    # Epoch lists: each flag takes one or more ints and defaults to an empty
    # list (a fresh list object per flag).
    for flag in ("--test_epochs", "--save_epochs", "--checkpoint_epochs"):
        parser.add_argument(flag, nargs="+", type=int, default=[])

    # Optional string, None when omitted.
    # NOTE(review): presumably a checkpoint path to resume from -- confirm
    # against litegs.training.start.
    parser.add_argument("--start_checkpoint", type=str, default=None)

    args = parser.parse_args(sys.argv[1:])

    # Pull each group's values back out of the flat namespace, in the same
    # order the groups were registered.
    lp, op, pp, dp = [group.extract(args) for group, _ in param_groups]

    litegs.training.start(
        lp,
        op,
        pp,
        dp,
        args.test_epochs,
        args.save_epochs,
        args.checkpoint_epochs,
        args.start_checkpoint,
    )