forked from andersbll/autoencoding_beyond_pixels
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcond_vaegan.py
More file actions
179 lines (156 loc) · 6.37 KB
/
cond_vaegan.py
File metadata and controls
179 lines (156 loc) · 6.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from copy import deepcopy
import numpy as np
import cudarray as ca
import deeppy as dp
import deeppy.expr as expr
from vaegan import KLDivergence, NegativeGradient, ScaleGradient, SquareError, WeightedParameter
class AppendSpatially(expr.base.Binary):
    """Append per-sample feature vectors to an image batch as constant
    feature maps: (b, c, h, w) and (b, f) combine to (b, c + f, h, w).

    Each feature value is broadcast over the full h x w grid of its
    sample and concatenated along the channel axis.
    """

    def __call__(self, imgs, feats):
        self.img_expr = imgs
        self.feat_expr = feats
        self.inputs = [imgs, feats]
        return self

    def setup(self):
        n_imgs, n_channels, height, width = self.img_expr.out_shape
        n_feat_rows, n_feats = self.feat_expr.out_shape
        if n_imgs != n_feat_rows:
            raise ValueError('batch size mismatch')
        self.out_shape = (n_imgs, n_channels + n_feats, height, width)
        self.out = ca.empty(self.out_shape)
        self.out_grad = ca.empty(self.out_shape)
        # Scratch buffer that holds the spatially broadcast features.
        self.bcast = ca.zeros((n_imgs, n_feats, height, width))

    def fprop(self):
        # Broadcast (b, f) -> (b, f, h, w) by adding a (b, f, 1, 1) view
        # onto a zeroed buffer, then concatenate channel-wise.
        self.bcast.fill(0.0)
        feat_maps = ca.reshape(self.feat_expr.out,
                               self.feat_expr.out.shape + (1, 1))
        ca.add(feat_maps, self.bcast, out=self.bcast)
        ca.extra.concatenate(self.img_expr.out, self.bcast, axis=1,
                             out=self.out)

    def bprop(self):
        # Route the image slice of the gradient back to the image input;
        # the feature slice is dumped into the scratch buffer, i.e. the
        # conditioning input receives no gradient (intentional here,
        # since conditioning labels come from a Source node).
        ca.extra.split(self.out_grad, a_size=self.img_expr.out_shape[1],
                       axis=1, out_a=self.img_expr.out_grad,
                       out_b=self.bcast)
class ConditionalSequential(expr.Sequential):
    """Sequential chain that also forwards a conditioning input y.

    Ops that consume the condition (Concatenate, AppendSpatially) are
    called as op(x, y); every other op is called as op(x).
    """

    def __call__(self, x, y):
        out = x
        for node in self.collection:
            takes_condition = isinstance(
                node, (expr.Concatenate, AppendSpatially))
            out = node(out, y) if takes_condition else node(out)
        return out
class ConditionalVAEGAN(dp.base.Model, dp.base.CollectionMixin):
    """Conditional VAE/GAN model built on deeppy expression graphs.

    `mode` selects the training objective:
      - 'vae':    reconstruction error + KL divergence only.
      - 'gan':    adversarial loss only.
      - 'vaegan': both, with the generator's VAE gradient scaled by
        `vae_grad_scale` and its adversarial gradient flowing through a
        weight-sharing negated copy (`generator_neg`).
    All submodels are conditioned on side information y.
    """

    def __init__(self, encoder, sampler, generator, discriminator, mode,
                 reconstruct_error=None, vae_grad_scale=1.0):
        self.encoder = encoder
        self.sampler = sampler
        self.mode = mode
        self.discriminator = discriminator
        self.vae_grad_scale = vae_grad_scale
        # Clip margin keeping discriminator outputs away from 0/1 so the
        # log() in the GAN loss stays finite.
        self.eps = 1e-4
        if reconstruct_error is None:
            reconstruct_error = SquareError()
        self.reconstruct_error = reconstruct_error
        # Unwrap any WeightedParameter wrappers (e.g. from a previous
        # construction) back to the raw parameters.
        generator.params = [p.parent if isinstance(p, WeightedParameter) else p
                            for p in generator.params]
        if self.mode == 'vaegan':
            # Scale the VAE reconstruction gradient into the generator.
            generator.params = [WeightedParameter(p, vae_grad_scale)
                                for p in generator.params]
            # generator_neg shares the underlying weights (p.share()) and
            # carries the adversarial gradient for the generator.
            self.generator_neg = deepcopy(generator)
            self.generator_neg.params = [p.share() for p in generator.params]
        if self.mode == 'gan':
            # Pure GAN: negate the gradient so the generator ascends the
            # discriminator's loss.
            generator.params = [WeightedParameter(p, -1.0)
                                for p in generator.params]
        self.generator = generator
        self.collection = [self.encoder, self.sampler, self.generator, self.discriminator]

    def _embed_expr(self, x, y):
        # Deterministic embedding: use the posterior mean z_mu rather
        # than a stochastic sample.
        h_enc = self.encoder(x, y)
        z, z_mu, z_log_sigma, z_eps = self.sampler(h_enc)
        z = z_mu
        return z

    def _reconstruct_expr(self, z, y):
        # Decode latent codes back to input space, conditioned on y.
        return self.generator(z, y)

    def setup(self, x_shape, y_shape):
        """Build the training expression graph for the selected mode.

        x_shape/y_shape are the batched shapes of the data and the
        conditioning input; x_shape[0] is the batch size.
        """
        batch_size = x_shape[0]
        self.sampler.batch_size = x_shape[0]
        self.x_src = expr.Source(x_shape)
        self.y_src = expr.Source(y_shape)
        if self.mode in ['vae', 'vaegan']:
            # VAE branch: encode, sample, reconstruct.
            h_enc = self.encoder(self.x_src, self.y_src)
            z, z_mu, z_log_sigma, z_eps = self.sampler(h_enc)
            self.kld = KLDivergence()(z_mu, z_log_sigma)
            x_tilde = self.generator(z, self.y_src)
            self.logpxz = self.reconstruct_error(x_tilde, self.x_src)
            loss = 0.5*self.kld + expr.sum(self.logpxz)
        if self.mode in ['gan', 'vaegan']:
            y = self.y_src
            if self.mode == 'gan':
                # Generate from prior samples only.
                z = self.sampler.samples()
                x_tilde = self.generator(z, y)
                gen_size = batch_size
            elif self.mode == 'vaegan':
                # Block the adversarial gradient from reaching the encoder
                # through z, then generate from both the (detached)
                # posterior codes and fresh prior samples z_eps.
                z = ScaleGradient(0.0)(z)
                z = expr.Concatenate(axis=0)(z, z_eps)
                y = expr.Concatenate(axis=0)(y, self.y_src)
                x_tilde = self.generator_neg(z, y)
                gen_size = batch_size*2
            # Stack real data first, generated samples after; y ends up as
            # matching repetitions of y_src for every stacked row.
            x = expr.Concatenate(axis=0)(self.x_src, x_tilde)
            y = expr.Concatenate(axis=0)(y, self.y_src)
            d = self.discriminator(x, y)
            d = expr.clip(d, self.eps, 1.0-self.eps)
            real_size = batch_size
            # Per-row affine map so that gan_loss = log(d) for real rows
            # and log(1 - d) for generated rows.
            sign = np.ones((real_size + gen_size, 1), dtype=ca.float_)
            sign[real_size:] = -1.0
            offset = np.zeros_like(sign)
            offset[real_size:] = 1.0
            self.gan_loss = expr.log(d*sign + offset)
            if self.mode == 'gan':
                loss = expr.sum(-self.gan_loss)
            elif self.mode == 'vaegan':
                loss = loss + expr.sum(-self.gan_loss)
        self._graph = expr.ExprGraph(loss)
        self._graph.out_grad = ca.array(1.0)
        self._graph.setup()

    @property
    def params(self):
        """Return (encoder, generator, discriminator) parameter lists;
        lists not trained in the current mode are empty."""
        enc_params = []
        gen_params = self.generator.params
        dis_params = []
        if self.mode != 'vae':
            dis_params = self.discriminator.params
        if self.mode != 'gan':
            enc_params = self.encoder.params + self.sampler.params
        return enc_params, gen_params, dis_params

    def update(self, x, y):
        """Run one forward/backward pass on batch (x, y).

        Returns (d_x_loss, d_z_loss, kld): discriminator loss on real and
        generated rows (0 in 'vae' mode) and the KL term (0 in 'gan'
        mode; note kld is returned as an np.array, not a float).
        """
        self.x_src.out = x
        self.y_src.out = y
        self._graph.fprop()
        self._graph.bprop()
        kld = 0
        d_x_loss = 0
        d_z_loss = 0
        if self.mode != 'gan':
            kld = np.array(self.kld.out)
        if self.mode != 'vae':
            # Rows [:batch_size] of gan_loss are the real samples; the
            # rest are generated (see the Concatenate order in setup()).
            gan_loss = -np.array(self.gan_loss.out)
            batch_size = x.shape[0]
            d_x_loss = float(np.mean(gan_loss[:batch_size]))
            d_z_loss = float(np.mean(gan_loss[batch_size:]))
        return d_x_loss, d_z_loss, kld

    def _batchwise(self, x, y, expr_fun):
        # Evaluate expr_fun over (x, y) in batches and stitch the outputs
        # back together, trimming any padding added by the last batch.
        x = dp.input.Input.from_any(x)
        y = dp.input.Input.from_any(y)
        x_src = expr.Source(x.x_shape)
        y_src = expr.Source(y.x_shape)
        graph = expr.ExprGraph(expr_fun(x_src, y_src))
        graph.setup()
        out = []
        for x_batch, y_batch in zip(x.batches(), y.batches()):
            x_src.out = x_batch['x']
            y_src.out = y_batch['x']
            graph.fprop()
            out.append(np.array(graph.out))
        out = np.concatenate(out)[:x.n_samples]
        return out

    def embed(self, x, y):
        """ Input to hidden. """
        return self._batchwise(x, y, self._embed_expr)

    def reconstruct(self, z, y):
        """ Hidden to input. """
        return self._batchwise(z, y, self._reconstruct_expr)