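# train_bilinear.py
# Train the bilinear network on Human3.6M: regress camera-space joint
# coordinates (H36M.Annotation.S) from image-space keypoints
# (H36M.Annotation.Part) with an MSE loss.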
import torch
import torch.nn as nn
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
from tensorboardX import SummaryWriter
import H36M
import model
from util import config
from util.log import get_logger
logger, log_dir, comment = get_logger(comment=config.bilinear.comment)
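# Log a configuration banner only for fresh runs, i.e. when there is no
# comment to resume from or its checkpoint directory does not exist yet.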
if config.bilinear.comment is None or not os.path.exists('save/{comment}/parameter'.format(comment=comment)):
    logger.info(' ')
    logger.info(' ')
    logger.info('===========================================================')
    logger.info('Comment : ' + comment + ' ')
    logger.info('===========================================================')
    logger.info('Architecture : ' + 'Bilinear' + ' ')
    logger.info(' -protocol : ' + config.bilinear.protocol + ' ')
    logger.info(' -task : ' + H36M.Task.Train + ' ')
    logger.info(' -device : ' + str(config.bilinear.device) + ' ')
    logger.info('===========================================================')
    logger.info('Data : ' + 'Human3.6M' + ' ')
    logger.info(' -directory : ' + config.bilinear.data_dir + ' ')
    logger.info(' -mini batch : ' + str(config.bilinear.batch_size) + ' ')
    logger.info(' -shuffle : ' + 'True' + ' ')
    logger.info(' -worker : ' + str(config.bilinear.num_workers) + ' ')
    logger.info('===========================================================')
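# Human3.6M training split, shuffled every epoch and pinned in host memory
# for faster transfers to the device.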
data = DataLoader(
    H36M.Dataset(
        data_dir=config.bilinear.data_dir,
        task=H36M.Task.Train,
        protocol=config.bilinear.protocol,
    ),
    batch_size=config.bilinear.batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=config.bilinear.num_workers,
)
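# Build the model and optimizer; when a comment is given, resume weights,
# optimizer state, global step, and epoch count from '{log_dir}/parameter'.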
bilinear, optimizer, step, train_epoch = model.bilinear.load(
    device=config.bilinear.device,
    parameter_dir='{log_dir}/parameter'.format(log_dir=log_dir) if config.bilinear.comment is not None else None,
)
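# MSE between predicted and ground-truth camera-space joints; training curves
# are written to TensorBoard under '{log_dir}/visualize'.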
criterion = nn.MSELoss()
writer = SummaryWriter(log_dir='{log_dir}/visualize'.format(
    log_dir=log_dir,
))
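# Switch to training mode and run 10 more epochs after the last completed one.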
bilinear.train()
for epoch in range(train_epoch + 1, train_epoch + 10 + 1):
    with tqdm(total=len(data), desc='%d epoch' % epoch) as progress:
        with torch.set_grad_enabled(True):
            for subset, _, _, _ in data:
                in_image_space = subset[H36M.Annotation.Part]
                in_camera_space = subset[H36M.Annotation.S]

                # Learning rate decay.
                if config.bilinear.lr_decay.activate and config.bilinear.lr_decay.condition(step):
                    lr = config.bilinear.lr_decay.function(step)
                    logger.info('Learning rate decay to {lr} (step: {step})'.format(lr=lr, step=step))
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr

                in_image_space = in_image_space.to(config.bilinear.device)
                in_camera_space = in_camera_space.to(config.bilinear.device)

                optimizer.zero_grad()
                prediction = bilinear(in_image_space)
                loss = criterion(prediction, in_camera_space)
                loss.backward()
                nn.utils.clip_grad_norm_(bilinear.parameters(), max_norm=1)
                optimizer.step()

                # Too-frequent TensorBoard updates reduce performance; log only the scalar loss per step.
                writer.add_scalar('BI/loss', loss.item(), step)
                progress.set_postfix(loss=loss.item())
                progress.update(1)
                step += 1
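
    # Checkpoint model and optimizer state at the end of every epoch.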
    save_dir = '{log_dir}/parameter'.format(log_dir=log_dir)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_to = '{save_dir}/{epoch}.save'.format(save_dir=save_dir, epoch=epoch)
    torch.save(
        {
            'epoch': epoch,
            'step': step,
            'state': bilinear.state_dict(),
            'optimizer': optimizer.state_dict(),
        },
        save_to,
    )
    logger.info('Epoch {epoch} saved (loss: {loss})'.format(epoch=epoch, loss=loss.item()))
writer.close()