Skip to content

Commit 3a00900

Browse files
Context Encoders: Cleaned up code. Sample in README
1 parent 0a96228 commit 3a00900

File tree

5 files changed

+48
-39
lines changed

5 files changed

+48
-39
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -188,6 +188,13 @@ $ cd implementations/context_encoder/
188188
$ python3 context_encoder.py
189189
```
190190

191+
<p align="center">
192+
<img src="assets/context_encoder.png" width="640"\>
193+
</p>
194+
<p align="center">
195+
Rows: Masked | Inpainted | Original | Masked | Inpainted | Original
196+
</p>
197+
191198
### Coupled GAN
192199
_Coupled Generative Adversarial Networks_
193200

assets/context_encoder.png

824 KB
Loading

implementations/context_encoder/context_encoder.py

Lines changed: 8 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -31,12 +31,12 @@
3131

3232
parser = argparse.ArgumentParser()
3333
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
34-
parser.add_argument('--batch_size', type=int, default=4, help='size of the batches')
34+
parser.add_argument('--batch_size', type=int, default=8, help='size of the batches')
3535
parser.add_argument('--dataset_name', type=str, default='img_align_celeba', help='name of the dataset')
3636
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
3737
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
3838
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
39-
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
39+
parser.add_argument('--n_cpu', type=int, default=4, help='number of cpu threads to use during batch generation')
4040
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
4141
parser.add_argument('--img_size', type=int, default=128, help='size of each image dimension')
4242
parser.add_argument('--mask_size', type=int, default=64, help='size of random mask')
@@ -83,7 +83,7 @@ def weights_init_normal(m):
8383
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
8484
dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_),
8585
batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
86-
sample_dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_),
86+
test_dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, mode='val'),
8787
batch_size=12, shuffle=True, num_workers=1)
8888

8989
# Optimizers
@@ -92,35 +92,11 @@ def weights_init_normal(m):
9292

9393
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
9494

95-
# Adversarial ground truths
96-
valid = Variable(Tensor(np.ones(patch)), requires_grad=False)
97-
fake = Variable(Tensor(np.zeros(patch)), requires_grad=False)
98-
99-
def apply_random_mask(imgs):
100-
idx = np.random.randint(0, opt.img_size-opt.mask_size, (imgs.shape[0], 2))
101-
102-
masked_imgs = imgs.clone()
103-
masked_parts = None
104-
for i, (y1, x1) in enumerate(idx):
105-
y2, x2 = y1 + opt.mask_size, x1 + opt.mask_size
106-
masked_part = masked_imgs[i:i+1, :, y1:y2, x1:x2].clone()
107-
masked_parts = masked_part if masked_parts is None else torch.cat((masked_parts, masked_part), 0)
108-
masked_imgs[i, :, y1:y2, x1:x2] = 1
109-
110-
return masked_imgs, masked_parts
111-
112-
def apply_center_mask(imgs):
113-
# Get upper-left pixel coordinate
114-
i = (imgs.shape[2] - opt.mask_size) // 2
115-
116-
masked_imgs = imgs.clone()
117-
masked_imgs[:, :, i:i+opt.mask_size, i:i+opt.mask_size] = 1
118-
119-
return masked_imgs, i
120-
12195
def save_sample(batches_done):
122-
samples = Variable(next(iter(sample_dataloader)).type(Tensor))
123-
masked_samples, i = apply_center_mask(samples)
96+
samples, masked_samples, i = next(iter(test_dataloader))
97+
samples = Variable(samples.type(Tensor))
98+
masked_samples = Variable(masked_samples.type(Tensor))
99+
i = i[0].item() # Upper-left coordinate of mask
124100
# Generate inpainted image
125101
gen_mask = generator(masked_samples)
126102
filled_samples = masked_samples.clone()
@@ -134,9 +110,7 @@ def save_sample(batches_done):
134110
# ----------
135111

136112
for epoch in range(opt.n_epochs):
137-
for i, imgs in enumerate(dataloader):
138-
139-
masked_imgs, masked_parts = apply_random_mask(imgs)
113+
for i, (imgs, masked_imgs, masked_parts) in enumerate(dataloader):
140114

141115
# Adversarial ground truths
142116
valid = Variable(Tensor(imgs.shape[0], *patch).fill_(1.0), requires_grad=False)

implementations/context_encoder/datasets.py

Lines changed: 31 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,17 +8,45 @@
88
import torchvision.transforms as transforms
99

1010
class ImageDataset(Dataset):
11-
def __init__(self, root, transforms_=None, mode='train'):
11+
def __init__(self, root, transforms_=None, img_size=128, mask_size=64, mode='train'):
1212
self.transform = transforms.Compose(transforms_)
13+
self.img_size = img_size
14+
self.mask_size = mask_size
15+
self.mode = mode
16+
self.files = sorted(glob.glob('%s/*.jpg' % root))
17+
self.files = self.files[:-4000] if mode == 'train' else self.files[-4000:]
1318

14-
self.files = sorted(glob.glob('%s/*.*' % root))
19+
def apply_random_mask(self, img):
20+
"""Randomly masks image"""
21+
y1, x1 = np.random.randint(0, self.img_size-self.mask_size, 2)
22+
y2, x2 = y1 + self.mask_size, x1 + self.mask_size
23+
masked_part = img[:, y1:y2, x1:x2]
24+
masked_img = img.clone()
25+
masked_img[:, y1:y2, x1:x2] = 1
26+
27+
return masked_img, masked_part
28+
29+
def apply_center_mask(self, img):
30+
"""Mask center part of image"""
31+
# Get upper-left pixel coordinate
32+
i = (self.img_size - self.mask_size) // 2
33+
masked_img = img.clone()
34+
masked_img[:, i:i+self.mask_size, i:i+self.mask_size] = 1
35+
36+
return masked_img, i
1537

1638
def __getitem__(self, index):
1739

1840
img = Image.open(self.files[index % len(self.files)])
1941
img = self.transform(img)
42+
if self.mode == 'train':
43+
# For training data perform random mask
44+
masked_img, aux = self.apply_random_mask(img)
45+
else:
46+
# For test data mask the center of the image
47+
masked_img, aux = self.apply_center_mask(img)
2048

21-
return img
49+
return img, masked_img, aux
2250

2351
def __len__(self):
2452
return len(self.files)

implementations/context_encoder/models.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,14 +11,14 @@ def downsample(in_feat, out_feat, normalize=True):
1111
layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
1212
if normalize:
1313
layers.append(nn.BatchNorm2d(out_feat, 0.8))
14-
layers.append(nn.LeakyReLU(0.2, inplace=True))
14+
layers.append(nn.LeakyReLU(0.2))
1515
return layers
1616

1717
def upsample(in_feat, out_feat, normalize=True):
1818
layers = [nn.ConvTranspose2d(in_feat, out_feat, 4, stride=2, padding=1)]
1919
if normalize:
2020
layers.append(nn.BatchNorm2d(out_feat, 0.8))
21-
layers.append(nn.ReLU(inplace=True))
21+
layers.append(nn.ReLU())
2222
return layers
2323

2424
self.model = nn.Sequential(

0 commit comments

Comments (0)