-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathGAN.jl
More file actions
52 lines (45 loc) · 1.52 KB
/
GAN.jl
File metadata and controls
52 lines (45 loc) · 1.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# This file contains the helper functions used to define the GANs: logistic
# cross-entropy losses, orthogonal weight initialization, and a dense network.
using LinearAlgebra
"""
    crossEntTrue(x)

Return `log(1 + exp(-x))`, i.e. `-log(sigmoid(x))`: the logistic cross-entropy
of the score `x` for a positive (`true`) label.

It is evaluated as `max(x, 0) - x + log1p(exp(-abs(x)))`, so `exp` is only ever
applied to a non-positive argument and cannot overflow. `log1p` keeps the small
correction term accurate when `abs(x)` is large; `log(1 + ...)` would round it
to exactly zero.
"""
function crossEntTrue(x)
    return max(x, zero(x)) - x + log1p(exp(-abs(x)))
end
"""
    crossEntFalse(x)

Return `log(1 + exp(x))`, i.e. `-log(1 - sigmoid(x))`: the logistic
cross-entropy of the score `x` for a negative (`false`) label.

It is evaluated as `max(x, 0) + log1p(exp(-abs(x)))`, so `exp` is only ever
applied to a non-positive argument and cannot overflow. `log1p` keeps the small
correction term accurate when `abs(x)` is large; `log(1 + ...)` would round it
to exactly zero.
"""
function crossEntFalse(x)
    return max(x, zero(x)) + log1p(exp(-abs(x)))
end
"""
    rand_orth(m, n)

Return a random `m × n` matrix with orthonormal columns, intended for weight
initialization. It is the thin `Q` factor of the QR decomposition of an
`m × n` matrix with i.i.d. standard normal entries.

Throws `ArgumentError` if `m < n`, because a length-`m` space holds at most `m`
orthonormal columns.
"""
function rand_orth(m, n)
    # Explicit check rather than @assert, which may be disabled and is meant
    # for internal invariants, not argument validation.
    m >= n || throw(ArgumentError("rand_orth requires m >= n, got m = $m, n = $n"))
    F = qr(randn(m, n))
    # Matrix(F.Q) materializes the thin m × n Q factor, not the full m × m one.
    return Matrix(F.Q)
end
"""
    ReLU(x)

Rectified linear unit, computed as `max(x, zero(x))`; the zero is built with
`zero(x)` so the result keeps the type of `x`.
"""
ReLU(x) = max(x, zero(x))
# Return the `rows × cols` matrix stored column-major in
# `weights[start:(start + rows * cols - 1)]` (a copy of that slice, reshaped).
_weight_block(weights, start, rows, cols) =
    reshape(weights[start:(start + rows * cols - 1)], (rows, cols))

"""
    denseNet(weights, data, inDim, outDim, depth, n_hidden, nonlin)

Apply a fully connected network, whose parameters are packed into the flat
vector `weights`, to `data`.

The matrices are read from `weights` consecutively, each stored column-major:
an `n_hidden × inDim` input layer, `depth` layers of size `n_hidden × n_hidden`,
and an `outDim × n_hidden` output layer. `nonlin` is broadcast elementwise after
the input layer and after every hidden layer, but not after the output layer.
No bias terms are used, and any entries of `weights` past the last matrix are
ignored.

`data` is left-multiplied by the input matrix, so it presumably is an
`inDim`-vector or an `inDim × batch` matrix (not checked here).
"""
function denseNet(weights, data, inDim, outDim, depth, n_hidden, nonlin)
    W_in = _weight_block(weights, 1, n_hidden, inDim)
    h = nonlin.(W_in * data)
    cursor = 1 + inDim * n_hidden

    for _ in 1:depth
        W_hidden = _weight_block(weights, cursor, n_hidden, n_hidden)
        h = nonlin.(W_hidden * h)
        cursor += n_hidden^2
    end

    W_out = _weight_block(weights, cursor, outDim, n_hidden)
    return W_out * h
end
"""
    init_denseNet(inDim, outDim, depth, n_hidden, scale=0.8)

Create a flat weight vector in the layout `denseNet` reads for the same
arguments. The layout is an `n_hidden × inDim` input block, then `depth` blocks
of size `n_hidden × n_hidden`, then an `outDim × n_hidden` output block, all
concatenated column-major.

Each block is `scale` times a random matrix from `rand_orth`. Tall blocks have
orthonormal columns; wide blocks are the adjoint of a tall one and so have
orthonormal rows.
"""
function init_denseNet(inDim, outDim, depth, n_hidden, scale=0.8)
    # rand_orth only builds tall (m >= n) matrices, so a wide block is
    # generated as the adjoint of a tall one.
    if n_hidden >= inDim
        out = (scale * rand_orth(n_hidden, inDim))[:]
    else
        out = (scale * rand_orth(inDim, n_hidden)')[:]
    end
    for _ in 1:depth
        append!(out, scale * rand_orth(n_hidden, n_hidden)[:])
    end
    # Output block (outDim × n_hidden), handled symmetrically to the input
    # block so that outDim > n_hidden is supported as well.
    if n_hidden >= outDim
        append!(out, (scale * rand_orth(n_hidden, outDim)')[:])
    else
        append!(out, (scale * rand_orth(outDim, n_hidden))[:])
    end
    return out
end