Skip to content

Commit e2eefdf

Browse files
committed
[TOPI][Testing] Enable conv2d NHWC fp16 topi testing for arm_cpu
This commit adds fp16 test cases to the conv2d NHWC TOPI schedules for `arm_cpu`. Following the example of apache#8529, the numpy reference conv2d output is computed in fp32 instead of fp16, while the absolute tolerance varies for each test case according to the size of the summed axis and the output's largest element.
1 parent f044eef commit e2eefdf

1 file changed

Lines changed: 32 additions & 7 deletions

File tree

tests/python/topi/test_topi_conv2d_nhwc.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack,
5454
),
5555
(
56-
"llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a",
56+
"llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+fullfp16",
5757
topi.arm_cpu.compute_conv2d_NHWC_hybrid,
5858
topi.arm_cpu.schedule_conv2d_NHWC_hybrid,
5959
),
@@ -64,7 +64,7 @@
6464
),
6565
)
6666

67-
dtype = tvm.testing.parameter("float32")
67+
dtype = tvm.testing.parameter("float16", "float32")
6868

6969
batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters(
7070
# Pad M, N, K
@@ -104,14 +104,36 @@ def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padd
104104
a_shape = (batch, in_height, in_width, in_channel)
105105
w_shape = (kernel, kernel, in_channel, num_filter)
106106

107+
np.random.seed(0)
107108
a_np = np.random.uniform(size=a_shape).astype(dtype)
108109
w_np = np.random.uniform(size=w_shape).astype(dtype)
109110
dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
110-
b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
111+
112+
# scipy.signal.convolve2d does not support float16 data types,
113+
# and the python fallback would be too slow for general use.
114+
conv_dtype = "float32" if dtype == "float16" else dtype
115+
b_np = tvm.topi.testing.conv2d_nhwc_python(
116+
a_np.astype(conv_dtype), dw_np.astype(conv_dtype), stride, padding
117+
).astype(dtype)
111118
return a_np, w_np, b_np
112119

113120

114-
def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilation):
121+
def get_tolerance(dtype, w_np, b_np):
122+
if dtype == "float16":
123+
# A summation in float16 with a single accumulator very
124+
# quickly runs into large rounding errors.
125+
# This tolerance is necessary to ensure no false negatives,
126+
# but it may introduce false positives, depending on schedule behaviour.
127+
num_values_summed = w_np.shape[0] * w_np.shape[1] * w_np.shape[2]
128+
next_float_gap_size = np.nextafter(b_np.max(), np.inf, dtype=b_np.dtype) - b_np.max()
129+
tol = {"rtol": 1e-5, "atol": num_values_summed * next_float_gap_size / 2}
130+
else:
131+
tol = {"rtol": 1e-5, "atol": 1e-7}
132+
133+
return tol
134+
135+
136+
def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation):
115137
a_np, w_np, b_np = ref_data
116138

117139
A = te.placeholder(a_np.shape, name="A", dtype=dtype)
@@ -137,7 +159,8 @@ def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilatio
137159
return
138160

139161
func(a, w, b)
140-
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
162+
tol = get_tolerance(dtype, w_np, b_np)
163+
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"])
141164

142165

143166
def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilation):
@@ -155,7 +178,8 @@ def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilatio
155178
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
156179
func = tvm.build(s, [A, W, B], target)
157180
func(a, w, b)
158-
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
181+
tol = get_tolerance(dtype, w_np, b_np)
182+
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"])
159183

160184

161185
def test_conv2d_nhwc_ohwi(ref_data, dtype, stride, padding, dilation):
@@ -184,7 +208,8 @@ def test_conv2d_nhwc_ohwi(ref_data, dtype, stride, padding, dilation):
184208
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
185209
func = tvm.build(s, [A, W, B], target)
186210
func(a, w, b)
187-
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
211+
tol = get_tolerance(dtype, w_np_hwio, b_np)
212+
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"])
188213

189214

190215
if __name__ == "__main__":

0 commit comments

Comments
 (0)