@@ -149,127 +149,6 @@ def reduce_sum(x, axis=None, keepdim=False):
149149 return x
150150
151151
class CorrelationLoss(nn.Module):
    r"""
    VGG-based perceptual loss with an additional attention-correlation term.

    Content and style terms follow the perceptual-loss formulation; the
    correlation term compares source-keyed attention maps of the two inputs.
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    def __init__(self, weights=(1.0, 1.0, 1.0, 1.0, 1.0)):
        """
        Args:
            weights: per-layer weights for the five relu*_1 content terms.
                (Tuple default instead of a mutable list default.)
        """
        super(CorrelationLoss, self).__init__()
        self.add_module('vgg', VGG19())
        self.criterion = torch.nn.L1Loss()
        self.weights = weights
        self.softmax = nn.Softmax(dim=-1)

    def compute_gram(self, x):
        """Return the batch Gram matrix (B x C x C) of feature map x,
        normalized by ch * h * w."""
        b, ch, h, w = x.size()
        f = x.view(b, ch, w * h)
        f_T = f.transpose(1, 2)
        G = f.bmm(f_T) / (h * w * ch)
        return G

    def comput_correlation(self, x_vgg, y_vgg, source_vgg):
        """Compute source-keyed attention maps for two feature maps.

        All three inputs are assumed to share the same (B, C, H, W) shape
        (they come from the same VGG layer) — TODO confirm with callers.

        Returns:
            (x_attention, y_attention): each B x N x N with N = H*W; rows
            sum to 1 (softmax over the last dimension).
        """
        b, _, h, w = x_vgg.shape
        n = h * w
        # Floor on the L2 norm: guards against division by ~0 (NaN guard).
        # clamp_min(eps) == torch.max(norm, eps) but is device-agnostic —
        # the original allocated the eps tensor with a hard-coded .cuda(),
        # which crashed on CPU-only machines.
        eps = 1e-4

        def _flatten_and_normalize(t):
            # Flatten spatial dims to B x N x C, then L2-normalize each
            # spatial feature vector along the channel axis.
            q = t.view(b, -1, n).permute(0, 2, 1)
            norm = q.pow(2).sum(dim=2, keepdim=True).sqrt().clamp_min(eps)
            return q / norm

        x_query = _flatten_and_normalize(x_vgg)          # B x N x C
        y_query = _flatten_and_normalize(y_vgg)          # B x N x C
        source_key = _flatten_and_normalize(source_vgg).permute(0, 2, 1)  # B x C x N

        # Temperature of 10 sharpens the softmax over cosine similarities.
        x_attention = self.softmax(torch.bmm(x_query, source_key) * 10)  # B x N x N
        y_attention = self.softmax(torch.bmm(y_query, source_key) * 10)  # B x N x N
        return x_attention, y_attention

    def __call__(self, x, y, source):
        """Return (content_loss, style_loss, correlation_loss) between images
        x and y, with the correlation term keyed on the source image."""
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)

        # Weighted L1 over the relu*_1 activations (content).
        content_loss = 0.0
        content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
        content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
        content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
        content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
        content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])

        # L1 over Gram matrices (style).
        style_loss = 0.0
        for layer in ('relu2_2', 'relu3_4', 'relu4_4', 'relu5_2'):
            style_loss += self.criterion(self.compute_gram(x_vgg[layer]),
                                         self.compute_gram(y_vgg[layer]))

        # Attention-correlation term keyed on the source image features.
        source_vgg = self.vgg(source)
        correlation_loss = 0
        for layer in ('relu5_1', 'relu4_1'):
            x_att, y_att = self.comput_correlation(
                x_vgg[layer], y_vgg[layer], source_vgg[layer])
            correlation_loss += self.criterion(x_att, y_att)

        return content_loss, style_loss, correlation_loss
243-
244-
class DomainAdaptionLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    def __init__(self):
        super().__init__()
        self.criterion = torch.nn.L1Loss()

    def compute_gram(self, x):
        """Return the batch Gram matrix (B x C x C) of feature map x,
        normalized by ch * h * w."""
        batch, channels, height, width = x.size()
        feats = x.view(batch, channels, height * width)
        gram = torch.bmm(feats, feats.transpose(1, 2))
        return gram / (height * width * channels)

    def __call__(self, x, y):
        """Return (content_loss, style_loss) between feature maps x and y:
        L1 on the raw maps and L1 on their Gram matrices."""
        content_loss = self.criterion(x, y)
        style_loss = self.criterion(self.compute_gram(x), self.compute_gram(y))
        return content_loss, style_loss
271-
272-
273152####################################################################################################
274153# neural style transform loss from neural_style_tutorial of pytorch
275154####################################################################################################
@@ -460,4 +339,4 @@ def forward(self, x):
460339 'relu5_3' : relu5_3 ,
461340 'relu5_4' : relu5_4 ,
462341 }
463- return out
342+ return out
0 commit comments