3131
3232parser = argparse .ArgumentParser ()
3333parser .add_argument ('--n_epochs' , type = int , default = 200 , help = 'number of epochs of training' )
34- parser .add_argument ('--batch_size' , type = int , default = 4 , help = 'size of the batches' )
34+ parser .add_argument ('--batch_size' , type = int , default = 8 , help = 'size of the batches' )
3535parser .add_argument ('--dataset_name' , type = str , default = 'img_align_celeba' , help = 'name of the dataset' )
3636parser .add_argument ('--lr' , type = float , default = 0.0002 , help = 'adam: learning rate' )
3737parser .add_argument ('--b1' , type = float , default = 0.5 , help = 'adam: decay of first order momentum of gradient' )
3838parser .add_argument ('--b2' , type = float , default = 0.999 , help = 'adam: decay of first order momentum of gradient' )
39- parser .add_argument ('--n_cpu' , type = int , default = 8 , help = 'number of cpu threads to use during batch generation' )
39+ parser .add_argument ('--n_cpu' , type = int , default = 4 , help = 'number of cpu threads to use during batch generation' )
4040parser .add_argument ('--latent_dim' , type = int , default = 100 , help = 'dimensionality of the latent space' )
4141parser .add_argument ('--img_size' , type = int , default = 128 , help = 'size of each image dimension' )
4242parser .add_argument ('--mask_size' , type = int , default = 64 , help = 'size of random mask' )
@@ -83,7 +83,7 @@ def weights_init_normal(m):
8383 transforms .Normalize ((0.5 ,0.5 ,0.5 ), (0.5 ,0.5 ,0.5 )) ]
8484dataloader = DataLoader (ImageDataset ("../../data/%s" % opt .dataset_name , transforms_ = transforms_ ),
8585 batch_size = opt .batch_size , shuffle = True , num_workers = opt .n_cpu )
86- sample_dataloader = DataLoader (ImageDataset ("../../data/%s" % opt .dataset_name , transforms_ = transforms_ ),
86+ test_dataloader = DataLoader (ImageDataset ("../../data/%s" % opt .dataset_name , transforms_ = transforms_ , mode = 'val' ),
8787 batch_size = 12 , shuffle = True , num_workers = 1 )
8888
8989# Optimizers
@@ -92,35 +92,11 @@ def weights_init_normal(m):
9292
9393Tensor = torch .cuda .FloatTensor if cuda else torch .FloatTensor
9494
95- # Adversarial ground truths
96- valid = Variable (Tensor (np .ones (patch )), requires_grad = False )
97- fake = Variable (Tensor (np .zeros (patch )), requires_grad = False )
98-
99- def apply_random_mask (imgs ):
100- idx = np .random .randint (0 , opt .img_size - opt .mask_size , (imgs .shape [0 ], 2 ))
101-
102- masked_imgs = imgs .clone ()
103- masked_parts = None
104- for i , (y1 , x1 ) in enumerate (idx ):
105- y2 , x2 = y1 + opt .mask_size , x1 + opt .mask_size
106- masked_part = masked_imgs [i :i + 1 , :, y1 :y2 , x1 :x2 ].clone ()
107- masked_parts = masked_part if masked_parts is None else torch .cat ((masked_parts , masked_part ), 0 )
108- masked_imgs [i , :, y1 :y2 , x1 :x2 ] = 1
109-
110- return masked_imgs , masked_parts
111-
112- def apply_center_mask (imgs ):
113- # Get upper-left pixel coordinate
114- i = (imgs .shape [2 ] - opt .mask_size ) // 2
115-
116- masked_imgs = imgs .clone ()
117- masked_imgs [:, :, i :i + opt .mask_size , i :i + opt .mask_size ] = 1
118-
119- return masked_imgs , i
120-
12195def save_sample (batches_done ):
122- samples = Variable (next (iter (sample_dataloader )).type (Tensor ))
123- masked_samples , i = apply_center_mask (samples )
96+ samples , masked_samples , i = next (iter (test_dataloader ))
97+ samples = Variable (samples .type (Tensor ))
98+ masked_samples = Variable (masked_samples .type (Tensor ))
99+ i = i [0 ].item () # Upper-left coordinate of mask
124100 # Generate inpainted image
125101 gen_mask = generator (masked_samples )
126102 filled_samples = masked_samples .clone ()
@@ -134,9 +110,7 @@ def save_sample(batches_done):
134110# ----------
135111
136112for epoch in range (opt .n_epochs ):
137- for i , imgs in enumerate (dataloader ):
138-
139- masked_imgs , masked_parts = apply_random_mask (imgs )
113+ for i , (imgs , masked_imgs , masked_parts ) in enumerate (dataloader ):
140114
141115 # Adversarial ground truths
142116 valid = Variable (Tensor (imgs .shape [0 ], * patch ).fill_ (1.0 ), requires_grad = False )
0 commit comments