Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f59e82b

Browse files
Answeror authored and antinucleon committed
Exactly reproduce 56 layers ResNet on CIFAR10 (#2046)
Exactly reproduce 56 layers ResNet on CIFAR10
1 parent 083eb16 commit f59e82b

File tree

3 files changed

+558
-0
lines changed

3 files changed

+558
-0
lines changed
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
'''
Reproducing https://github.com/gcr/torch-residual-networks
For image size of 32x32

References:

Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition"
'''
9+
import find_mxnet
10+
assert find_mxnet
11+
import mxnet as mx
12+
13+
14+
def get_conv(
    name,
    data,
    num_filter,
    kernel,
    stride,
    pad,
    with_relu,
    bn_momentum
):
    '''
    Build a bias-free convolution followed by batch normalization and,
    optionally, a ReLU activation.

    Args:
        name: Name of the convolution symbol. The batch norm and ReLU
            symbols are named ``name + '_bn'`` and ``name + '_relu'``.
        data: Input symbol.
        num_filter: Passed to the convolution as ``num_filter``.
        kernel: Passed to the convolution as ``kernel``.
        stride: Passed to the convolution as ``stride``.
        pad: Passed to the convolution as ``pad``.
        with_relu: When truthy, a ReLU activation is appended after the
            batch norm.
        bn_momentum: Passed to the batch norm as ``momentum``.

    Returns:
        The ReLU symbol when ``with_relu`` is truthy, otherwise the batch
        norm symbol.
    '''
    convolved = mx.symbol.Convolution(
        name=name,
        data=data,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        no_bias=True
    )
    normalized = mx.symbol.BatchNorm(
        name=name + '_bn',
        data=convolved,
        fix_gamma=False,
        momentum=bn_momentum,
        # Same eps as
        # https://github.com/soumith/cudnn.torch/blob/master/BatchNormalization.lua
        eps=1e-5
    )
    if not with_relu:
        return normalized
    # Original author's note: it's better to remove ReLU here, see
    # https://github.com/gcr/torch-residual-networks
    return mx.symbol.Activation(
        name=name + '_relu',
        data=normalized,
        act_type='relu'
    )
47+
48+
49+
def make_block(
    name,
    data,
    num_filter,
    dim_match,
    bn_momentum
):
    '''
    Build one residual block: two 3x3 conv+BN units plus a shortcut,
    summed and passed through a ReLU.

    Args:
        name: Prefix for the symbols created in this block
            (``_conv1``, ``_conv2``, ``_proj``, ``_relu``).
        data: Input symbol.
        num_filter: Number of filters of both convolutions (and of the
            projection shortcut, when one is created).
        dim_match: When truthy, the first convolution uses stride (1, 1)
            and the shortcut is ``data`` itself. Otherwise the first
            convolution uses stride (2, 2) and the shortcut is a strided
            projection convolution.
        bn_momentum: Batch norm momentum forwarded to ``get_conv``.

    Returns:
        The ReLU symbol applied to ``shortcut + conv2``.
    '''
    # Settings common to both 3x3 convolutions of the residual branch.
    shared = dict(
        num_filter=num_filter,
        kernel=(3, 3),
        pad=(1, 1),
        bn_momentum=bn_momentum
    )
    if dim_match:
        first_stride = (1, 1)
    else:
        first_stride = (2, 2)
    branch = get_conv(
        name=name + '_conv1',
        data=data,
        stride=first_stride,
        with_relu=True,
        **shared
    )
    # No ReLU on the second unit: the activation is applied after the sum.
    branch = get_conv(
        name=name + '_conv2',
        data=branch,
        stride=(1, 1),
        with_relu=False,
        **shared
    )

    # Identity shortcut unless the block changes dimensions.
    shortcut = data
    if not dim_match:
        # Alternative 1: projection with batch norm, like
        # http://ethereon.github.io/netscope/#/gist/db945b393d40bfa26006
        # Test accuracy 0.922+ on CIFAR10 with 56 layers
        # shortcut = get_conv(
        #     name=name + '_proj',
        #     data=data,
        #     num_filter=num_filter,
        #     kernel=(1, 1),
        #     stride=(2, 2),
        #     pad=(0, 0),
        #     with_relu=False,
        #     bn_momentum=bn_momentum
        # )

        # Type A shortcut (the one in use).
        # Note we use kernel (2, 2) rather than (1, 1) and a custom
        # initializer in train_cifar10_resnet.py
        # Test accuracy 0.918 on CIFAR10 with 56 layers and kernel (1, 1)
        # TODO(Answeror): Don't know why (1, 1) got much lower accuracy
        shortcut = mx.symbol.Convolution(
            name=name + '_proj',
            data=data,
            num_filter=num_filter,
            kernel=(2, 2),
            stride=(2, 2),
            pad=(0, 0),
            no_bias=True
        )

        # Alternative 2: same as above, but ugly.
        # MXNet doesn't have nn.Padding as in
        # https://github.com/gcr/torch-residual-networks/blob/master/residual-layers.lua
        # shortcut = mx.symbol.Pooling(
        #     data=data,
        #     name=name + '_pool',
        #     kernel=(2, 2),
        #     stride=(2, 2),
        #     pool_type='avg'
        # )
        # shortcut = mx.symbol.Concat(
        #     shortcut,
        #     mx.symbol.minimum(shortcut + 1, 0),
        #     num_args=2
        # )

    return mx.symbol.Activation(
        name=name + '_relu',
        data=shortcut + branch,
        act_type='relu'
    )
128+
129+
130+
def get_body(
    data,
    num_level,
    num_block,
    num_filter,
    bn_momentum
):
    '''
    Stack ``num_level`` levels of ``num_block`` residual blocks each.

    Blocks are named ``level<L>_block<B>`` with both indices starting at 1.
    The filter count doubles at every level, starting from ``num_filter``.
    The first block of every level except the first is built with
    ``dim_match=False``; all other blocks use ``dim_match=True``.

    Args:
        data: Input symbol.
        num_level: Number of levels.
        num_block: Number of residual blocks per level.
        num_filter: Filter count of the first level.
        bn_momentum: Batch norm momentum forwarded to ``make_block``.

    Returns:
        The output symbol of the last block, or ``data`` unchanged when no
        block is created.
    '''
    out = data
    for level in range(num_level):
        level_filters = num_filter * (2 ** level)
        for block in range(num_block):
            # Only the entry block of levels after the first changes
            # dimensions.
            changes_dim = level != 0 and block == 0
            out = make_block(
                name='level%d_block%d' % (level + 1, block + 1),
                data=out,
                num_filter=level_filters,
                dim_match=not changes_dim,
                bn_momentum=bn_momentum
            )
    return out
147+
148+
149+
def get_symbol(
    num_class,
    num_level=3,
    num_block=9,
    num_filter=16,
    bn_momentum=0.9,
    pool_kernel=(8, 8)
):
    '''
    Assemble the full network symbol: input batch norm, a first 3x3
    conv+BN+ReLU, the residual body, average pooling, a fully connected
    layer and a softmax output.

    Args:
        num_class: Number of outputs of the fully connected layer.
        num_level: Number of levels in the residual body.
        num_block: Number of residual blocks per level.
        num_filter: Filter count of the first convolution and first level.
        bn_momentum: Momentum used by every batch norm in the network.
        pool_kernel: Kernel of the average pooling applied after the body.
            NOTE(review): the default (8, 8) presumably covers the whole
            feature map for 32x32 inputs with the default 3 levels
            (two stride-2 blocks) -- confirm before changing either.

    Returns:
        The ``softmax`` output symbol.
    '''
    net = mx.symbol.Variable(name='data')
    # Simulate z-score normalization as that in
    # https://github.com/gcr/torch-residual-networks/blob/master/data/cifar-dataset.lua
    net = mx.symbol.BatchNorm(
        name='zscore',
        data=net,
        fix_gamma=True,
        momentum=bn_momentum
    )
    net = get_conv(
        name='conv0',
        data=net,
        num_filter=num_filter,
        kernel=(3, 3),
        stride=(1, 1),
        pad=(1, 1),
        with_relu=True,
        bn_momentum=bn_momentum
    )
    net = get_body(net, num_level, num_block, num_filter, bn_momentum)
    net = mx.symbol.Pooling(data=net, kernel=pool_kernel, pool_type='avg')
    # The flatten layer seems superfluous
    net = mx.symbol.Flatten(data=net)
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_class, name='fc')
    return mx.symbol.SoftmaxOutput(data=net, name='softmax')

0 commit comments

Comments
 (0)