-
Notifications
You must be signed in to change notification settings - Fork 152
Expand file tree
/
Copy pathpsp.py
More file actions
executable file
·141 lines (125 loc) · 6.26 KB
/
psp.py
File metadata and controls
executable file
·141 lines (125 loc) · 6.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
This file defines the core research contribution
"""
import math
import torch
from torch import nn
from models.stylegan2.model import Generator
from configs.paths_config import model_paths
from models.encoders import fpn_encoders, restyle_psp_encoders
from utils.model_utils import RESNET_MAPPING
class pSp(nn.Module):
    """ReStyle pSp model: an encoder producing style codes plus a StyleGAN2 decoder.

    The encoder output is treated as a residual. It is added either to the
    latent from a previous iteration (``latent``) or to the average latent
    (``self.latent_avg``), and the sum is decoded by ``Generator``.

    Reads these ``opts`` attributes: ``output_size``, ``encoder_type``,
    ``checkpoint_path``, ``stylegan_weights``, ``device``, ``dataset_type`` and
    ``input_nc``.
    """

    def __init__(self, opts):
        super(pSp, self).__init__()
        self.set_opts(opts)
        # Number of per-layer style codes: 2 * log2(output_size) - 2 (18 for 1024).
        # Presumably matches the Generator's number of style inputs -- confirm.
        self.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()

    def set_encoder(self):
        """Instantiate the encoder selected by ``opts.encoder_type``.

        Raises:
            ValueError: if ``opts.encoder_type`` is not a known encoder name.
        """
        # Builders are lambdas so only the selected encoder is constructed.
        builders = {
            'GradualStyleEncoder':
                lambda: fpn_encoders.GradualStyleEncoder(50, 'ir_se', self.n_styles, self.opts),
            'ResNetGradualStyleEncoder':
                lambda: fpn_encoders.ResNetGradualStyleEncoder(self.n_styles, self.opts),
            'BackboneEncoder':
                lambda: restyle_psp_encoders.BackboneEncoder(50, 'ir_se', self.n_styles, self.opts),
            'ResNetBackboneEncoder':
                lambda: restyle_psp_encoders.ResNetBackboneEncoder(self.n_styles, self.opts),
        }
        builder = builders.get(self.opts.encoder_type)
        if builder is None:
            # ValueError subclasses Exception, so existing `except Exception` handlers still work.
            raise ValueError(f'{self.opts.encoder_type} is not a valid encoders')
        return builder()

    def load_weights(self):
        """Load encoder/decoder weights and the average latent.

        If ``opts.checkpoint_path`` is set, both sub-networks come from that
        checkpoint. Otherwise the encoder is initialised from a pretrained
        backbone and the decoder from ``opts.stylegan_weights`` (key ``'g_ema'``).
        In that case the average latent is repeated ``n_styles`` times.
        """
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        if self.opts.checkpoint_path is not None:
            print(f'Loading ReStyle pSp from checkpoint: {self.opts.checkpoint_path}')
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.encoder.load_state_dict(self.__get_keys(ckpt, 'encoder'), strict=False)
            self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True)
            self.__load_latent_avg(ckpt)
        else:
            encoder_ckpt = self.__get_encoder_checkpoint()
            self.encoder.load_state_dict(encoder_ckpt, strict=False)
            print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}')
            # map_location='cpu' so GPU-saved weights also load on CPU-only hosts;
            # load_state_dict copies them onto the module's own device.
            ckpt = torch.load(self.opts.stylegan_weights, map_location='cpu')
            self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
            self.__load_latent_avg(ckpt, repeat=self.n_styles)

    def forward(self, x, latent=None, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None, average_code=False, input_is_full=False):
        """Encode ``x`` (or take it as codes) and decode it to images.

        Args:
            x: encoder input, or latent codes when ``input_code`` is true.
            latent: previous-iteration latent, added to the encoder output when
                ``x`` has 6 channels.
            resize: pool the output images to 256x256 with ``face_pool``.
            latent_mask: indices of style codes to overwrite (style mixing).
            input_code: treat ``x`` as codes and skip the encoder.
            randomize_noise: forwarded to the decoder.
            inject_latent: source codes for the masked indices; if None, masked
                codes are zeroed.
            return_latents: also return the decoder's latents.
            alpha: blend weight for ``inject_latent`` on the masked indices.
            average_code: force ``input_is_latent=True`` for the decoder.
            input_is_full: with ``input_code``, pass codes as already-latent.

        Returns:
            ``images``, or ``(images, result_latent)`` when ``return_latents``.

        Requires ``self.latent_avg`` to be set when the average-latent residual is
        used. It is None if the loaded checkpoint had no ``'latent_avg'`` entry.
        """
        if input_code:
            # Copy before in-place style mixing so the caller's tensor is not modified.
            codes = x.clone() if latent_mask is not None else x
        else:
            codes = self.encoder(x)
            # residual step
            # 6 channels presumably = input image concatenated with the previous
            # iteration's output (3 + 3) -- confirm with callers.
            if x.shape[1] == 6 and latent is not None:
                # learn error with respect to previous iteration
                codes = codes + latent
            else:
                # first iteration is with respect to the avg latent code
                codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)

        if latent_mask is not None:
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0

        if average_code:
            input_is_latent = True
        else:
            input_is_latent = (not input_code) or (input_is_full)
        images, result_latent = self.decoder([codes],
                                             input_is_latent=input_is_latent,
                                             randomize_noise=randomize_noise,
                                             return_latents=return_latents)

        if resize:
            images = self.face_pool(images)

        if return_latents:
            return images, result_latent
        else:
            return images

    def set_opts(self, opts):
        """Store the options object used by the other methods."""
        self.opts = opts

    def __load_latent_avg(self, ckpt, repeat=None):
        """Set ``self.latent_avg`` from ``ckpt['latent_avg']`` (or None if absent).

        The tensor is moved to ``opts.device``. If ``repeat`` is given, it is
        tiled with ``.repeat(repeat, 1)``.
        """
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None

    def __get_encoder_checkpoint(self):
        """Return pretrained backbone weights used to initialise the encoder.

        If ``'ffhq'`` is in ``opts.dataset_type``, the IR-SE50 weights are used as
        is. Otherwise ResNet-34 weights are used, with parameter names rewritten
        through ``RESNET_MAPPING`` to the encoder's naming. When
        ``opts.input_nc != 3``, the first conv layer is widened so the
        pretrained RGB filters occupy the first 3 input channels.
        """
        use_irse50 = "ffhq" in self.opts.dataset_type
        if use_irse50:
            print('Loading encoders weights from irse50!')
            encoder_ckpt = torch.load(model_paths['ir_se50'], map_location='cpu')
            first_conv_key = 'input_layer.0.weight'
        else:
            print('Loading encoders weights from resnet34!')
            encoder_ckpt = torch.load(model_paths['resnet34'], map_location='cpu')
            first_conv_key = 'conv1.weight'

        # Transfer the RGB input of the pretrained network to the first 3 input channels of pSp's encoder
        if self.opts.input_nc != 3:
            encoder_ckpt[first_conv_key] = self.__expand_input_channels(
                encoder_ckpt[first_conv_key], self.opts.input_nc)

        if use_irse50:
            return encoder_ckpt
        # Return the renamed dict so the weights actually match the encoder's keys
        # (with strict=False, unmatched keys would otherwise be skipped silently).
        return self.__remap_resnet_keys(encoder_ckpt)

    @staticmethod
    def __expand_input_channels(weight, input_nc):
        """Return a conv weight with ``input_nc`` input channels.

        Channels ``[:3]`` receive the given ``weight`` and any remaining channels
        are filled with ``torch.randn`` values.
        """
        out_channels, _, kernel_h, kernel_w = weight.shape
        expanded = torch.randn(out_channels, input_nc, kernel_h, kernel_w, dtype=torch.float32)
        expanded[:, :3, :, :] = weight
        return expanded

    @staticmethod
    def __remap_resnet_keys(state_dict):
        """Rename parameter keys using the first matching ``RESNET_MAPPING`` entry."""
        remapped = {}
        for key, value in state_dict.items():
            for original_name, psp_name in RESNET_MAPPING.items():
                if original_name in key:
                    key = key.replace(original_name, psp_name)
                    break  # one rename per key; avoids double-popping the same entry
            remapped[key] = value
        return remapped

    @staticmethod
    def __get_keys(d, name):
        """Extract the sub-state-dict of module ``name``, stripping its prefix.

        If ``d`` has a ``'state_dict'`` entry, that entry is searched instead.
        Only keys starting with ``name + '.'`` match, so a sibling such as
        ``'encoder_ema.*'`` is not captured by ``'encoder'``.
        """
        if 'state_dict' in d:
            d = d['state_dict']
        prefix = f'{name}.'
        return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}