-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathclf_trainer.py
More file actions
109 lines (88 loc) · 3.55 KB
/
clf_trainer.py
File metadata and controls
109 lines (88 loc) · 3.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from os.path import join
import argparse
import numpy as np
from tqdm import tqdm, trange
from scipy.io import savemat
import torch
import torch.backends.cudnn as cudnn
import classifier
import utils
import config
# Fix RNG seeds for reproducible runs (data shuffling, weight init).
# NOTE(review): torch.manual_seed alone does not seed CUDA on older
# torch versions — confirm reproducibility on GPU if that matters here.
np.random.seed(67)
torch.manual_seed(67)
def main(args):
    """Train a ConvNet classifier and log per-epoch metrics.

    Builds data loaders and the classifier from ``args``, trains for
    ``args.n_epochs`` epochs, saves a rolling checkpoint each epoch, and
    on exit (normal completion or Ctrl-C) dumps the metric histories to
    ``<exp_dir>/logs.mat``.

    Args:
        args: argparse.Namespace produced by the parser in ``__main__``
            (expects exp_dir, n_epochs, init_lr, momentum, weight_decay,
            device, model, multi_gpu, tensorboard, lr_dec_rate,
            lr_dec_int attributes).
    """
    train_loader, test_loader, n_classes = config.create_loaders(args)
    clf = classifier.ConvNet(n_classes,
                             args.init_lr,
                             args.momentum,
                             args.weight_decay,
                             args.device,
                             args.exp_dir,
                             '',
                             args.model,
                             args.multi_gpu)
    # Per-epoch metric histories; written to logs.mat at the end.
    test_accs = np.zeros([args.n_epochs], 'float32')
    train_accs = np.zeros([args.n_epochs], 'float32')
    train_losses = np.zeros([args.n_epochs], 'float32')
    learning_rates = np.zeros([args.n_epochs], 'float32')
    if args.tensorboard:
        from tensorboard_logger import log_value
    try:
        for epoch in trange(args.n_epochs, ncols=100):
            # Step-decay schedule: classifier adjusts its own optimizer LR.
            lr = clf.adjust_learning_rate(epoch, args.lr_dec_rate, args.lr_dec_int)
            ########## Classifier update ##########
            train_loss, train_acc = clf.train_epoch(train_loader)
            ########## Evaluation ##########
            test_acc = clf.test(test_loader)
            if args.tensorboard:
                log_value('test_acc', test_acc, epoch)
                log_value('train_acc', train_acc, epoch)
                log_value('train_loss', train_loss, epoch)
                log_value('learning_rate', lr, epoch)
            else:
                tqdm.write('train_acc:{:.1f}, test_acc{:.1f}'.format(
                    train_acc, test_acc))
            test_accs[epoch] = test_acc
            train_accs[epoch] = train_acc
            train_losses[epoch] = train_loss
            learning_rates[epoch] = lr
            # Overwrite a single rolling checkpoint every epoch.
            torch.save(
                {'epoch': epoch, 'clf': clf.net.state_dict()},
                join(args.exp_dir, 'state_dict.ckpt'))
    except KeyboardInterrupt:
        # Ctrl-C is a supported way to stop early; fall through so the
        # metrics gathered so far are still saved below.
        # (Removed leftover `import ipdb; ipdb.set_trace()` — an
        # interactive debugger hook that hangs unattended runs.)
        print('Early exit')
    savemat(
        join(args.exp_dir, 'logs.mat'),
        {'train_accs': train_accs, 'test_accs': test_accs,
         'train_losses': train_losses, 'learning_rates': learning_rates})
if __name__ == "__main__":
    def _str2bool(v):
        """Parse a boolean CLI value.

        Fixes the classic argparse pitfall: ``type=bool`` converts any
        non-empty string (including the string 'False') to True, making
        it impossible to disable a flag from the command line.
        """
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got %r' % v)

    parser = argparse.ArgumentParser()
    # Paths and dataset/model selection.
    parser.add_argument('--exp_dir', type=str, default='./log')
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100'])
    parser.add_argument('--model', type=str, default='resnet-10',
                        choices=['resnet-10', 'resnet-18', 'resnet-34', 'resnet-50'])
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cpu', 'cuda'])
    # Was `type=bool`, which silently parsed 'False' as True.
    parser.add_argument('--multi_gpu', type=_str2bool, default=True)
    # Optimization hyper-parameters.
    parser.add_argument('--init_lr', type=float, default=0.1)
    parser.add_argument('--lr_dec_rate', type=float, default=0.1)
    parser.add_argument('--lr_dec_int', type=int, default=100)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--momentum', type=float, default=0.9)
    # Was `type=bool` — same pitfall as --multi_gpu.
    parser.add_argument('--tensorboard', type=_str2bool, default=True)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--test_batch_size', type=int, default=128)
    parser.add_argument('--n_epochs', type=int, default=300)
    FLAGS = parser.parse_args()
    # Fall back to CPU when CUDA was requested but is unavailable.
    if FLAGS.device == 'cuda' and torch.cuda.is_available():
        cudnn.benchmark = True
    else:
        FLAGS.device = 'cpu'
    utils.write_logs(FLAGS)
    if FLAGS.tensorboard:
        from tensorboard_logger import configure
        configure(FLAGS.exp_dir)
    main(FLAGS)