Skip to content

Commit 2bcb11b

Browse files
authored
Update external_function.py
1 parent 5d340d7 commit 2bcb11b

File tree

1 file changed

+1
-122
lines changed

1 file changed

+1
-122
lines changed

models/external_function.py

Lines changed: 1 addition & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -149,127 +149,6 @@ def reduce_sum(x, axis=None, keepdim=False):
149149
return x
150150

151151

152-
class CorrelationLoss(nn.Module):
153-
r"""
154-
Perceptual loss, VGG-based
155-
https://arxiv.org/abs/1603.08155
156-
https://github.com/dxyang/StyleTransfer/blob/master/utils.py
157-
"""
158-
159-
def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
160-
super(CorrelationLoss, self).__init__()
161-
self.add_module('vgg', VGG19())
162-
self.criterion = torch.nn.L1Loss()
163-
self.weights = weights
164-
self.softmax = nn.Softmax(dim=-1)
165-
166-
def compute_gram(self, x):
167-
b, ch, h, w = x.size()
168-
f = x.view(b, ch, w * h)
169-
f_T = f.transpose(1, 2)
170-
G = f.bmm(f_T) / (h * w * ch)
171-
return G
172-
173-
def comput_correlation(self, x_vgg, y_vgg, source_vgg):
174-
input_shape = list(x_vgg.shape)
175-
176-
x_query = x_vgg.view(input_shape[0], -1, input_shape[2] * input_shape[3]).permute(0, 2,
177-
1) # B X N X C
178-
179-
y_query = y_vgg.view(input_shape[0], -1, input_shape[2] * input_shape[3]).permute(0, 2,
180-
1) # B X N X C
181-
182-
source_key = source_vgg.view(input_shape[0], -1, input_shape[2] * input_shape[3]).permute(0, 2,
183-
1) # B X N X C
184-
185-
escape_NaN = torch.FloatTensor([1e-4])
186-
escape_NaN = escape_NaN.cuda()
187-
188-
max_x = torch.max(torch.sqrt(reduce_sum(torch.pow(x_query, 2),
189-
axis=[2],
190-
keepdim=True)), escape_NaN)
191-
max_y = torch.max(torch.sqrt(reduce_sum(torch.pow(y_query, 2),
192-
axis=[2],
193-
keepdim=True)), escape_NaN)
194-
max_source = torch.max(torch.sqrt(reduce_sum(torch.pow(source_key, 2),
195-
axis=[2],
196-
keepdim=True)), escape_NaN)
197-
198-
x_query = x_query / max_x
199-
y_query = y_query / max_y
200-
source_key = source_key / max_source
201-
202-
source_key = source_key.permute(0, 2, 1)
203-
204-
x_energy = torch.bmm(x_query, source_key) # transpose check
205-
x_attention = self.softmax(x_energy * 10) # BX (N) X (N)
206-
207-
y_energy = torch.bmm(y_query, source_key) # transpose check
208-
y_attention = self.softmax(y_energy * 10) # BX (N) X (N)
209-
return x_attention, y_attention
210-
211-
def __call__(self, x, y, source):
212-
# Compute features
213-
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
214-
215-
content_loss = 0.0
216-
content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
217-
content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
218-
content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
219-
content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
220-
content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
221-
222-
# Compute loss
223-
style_loss = 0.0
224-
style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
225-
style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
226-
style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
227-
style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
228-
229-
correlation_loss = 0
230-
231-
source_vgg = self.vgg(source)
232-
233-
x_attention_5, y_attention_5 = self.comput_correlation(x_vgg['relu5_1'], y_vgg['relu5_1'], source_vgg['relu5_1'])
234-
235-
correlation_loss += self.criterion(x_attention_5, y_attention_5)
236-
237-
x_attention_4, y_attention_4 = self.comput_correlation(x_vgg['relu4_1'], y_vgg['relu4_1'],
238-
source_vgg['relu4_1'])
239-
240-
correlation_loss += self.criterion(x_attention_4, y_attention_4)
241-
242-
return content_loss, style_loss, correlation_loss
243-
244-
245-
class DomainAdaptionLoss(nn.Module):
246-
r"""
247-
Perceptual loss, VGG-based
248-
https://arxiv.org/abs/1603.08155
249-
https://github.com/dxyang/StyleTransfer/blob/master/utils.py
250-
"""
251-
252-
def __init__(self):
253-
super(DomainAdaptionLoss, self).__init__()
254-
self.criterion = torch.nn.L1Loss()
255-
256-
def compute_gram(self, x):
257-
b, ch, h, w = x.size()
258-
f = x.view(b, ch, w * h)
259-
f_T = f.transpose(1, 2)
260-
G = f.bmm(f_T) / (h * w * ch)
261-
return G
262-
263-
def __call__(self, x, y):
264-
# Compute features
265-
content_loss = self.criterion(x, y)
266-
267-
# Compute loss
268-
style_loss = self.criterion(self.compute_gram(x), self.compute_gram(y))
269-
270-
return content_loss, style_loss
271-
272-
273152
####################################################################################################
274153
# neural style transform loss from neural_style_tutorial of pytorch
275154
####################################################################################################
@@ -460,4 +339,4 @@ def forward(self, x):
460339
'relu5_3': relu5_3,
461340
'relu5_4': relu5_4,
462341
}
463-
return out
342+
return out

0 commit comments

Comments
 (0)