Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions python/tvm/relay/op/strategy/adreno.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,18 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target):
or (data_layout == "NCHW" and kernel_layout == "OIHW4o")
):
if len(kernel.shape) == 4:
_, _, kh, kw = get_const_tuple(kernel.shape)
oc, _, kh, kw = get_const_tuple(kernel.shape)
else:
_, _, kh, kw, _ = get_const_tuple(kernel.shape)
oc, _, kh, kw, _ = get_const_tuple(kernel.shape)
# We cannot use textures when the number of output channels is less than 4,
# so we fall back to the compute functions and schedules from CUDA.
if len(kernel.shape) == 4 and oc < 4:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
name="conv2d_nchw.cuda",
)
return strategy
if (
(2 < kh < 8 and 2 < kw < 8 and kh == kw)
and (stride_h == 1 and stride_w == 1)
Expand All @@ -69,9 +78,18 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target):
or (data_layout == "NHWC" and kernel_layout == "HWIO4o")
):
if len(kernel.shape) == 4:
kh, kw, _, _ = get_const_tuple(kernel.shape)
kh, kw, _, oc = get_const_tuple(kernel.shape)
else:
kh, kw, _, _, _ = get_const_tuple(kernel.shape)
kh, kw, _, oc, _ = get_const_tuple(kernel.shape)
# We cannot use textures when the number of output channels is less than 4,
# so we fall back to the compute functions and schedules from CUDA.
if len(kernel.shape) == 4 and oc < 4:
strategy.add_implementation(
wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
name="conv2d_nhwc.gpu",
)
return strategy
if (
(2 < kh < 8 and 2 < kw < 8 and kh == kw)
and (stride_h == 1 and stride_w == 1)
Expand Down Expand Up @@ -125,12 +143,21 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target):
if (data_layout == "NCHW" and kernel_layout == "OIHW") or (
data_layout == "NCHW4c" and kernel_layout == "OIHW4o"
):
strategy.add_implementation(
wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc),
wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc),
name="depthwise_conv2d_nchwc.image2d",
plevel=10,
)
# We cannot use textures when the number of output channels is less than 4,
# so we fall back to the compute functions and schedules from CUDA.
if len(kernel.shape) == 4 and oc < 4:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.cuda",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc),
wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc),
name="depthwise_conv2d_nchwc.image2d",
plevel=10,
)
elif (data_layout == "NHWC" and kernel_layout == "HWOI") or (
data_layout == "NHWC4c" and kernel_layout == "HWOI4o"
):
Expand Down
30 changes: 30 additions & 0 deletions tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,5 +1289,35 @@ def test_injective_nwo_inputs2(remote, target, dtype):
)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_conv2d_to_3_channels(remote, target, dtype):
    """1x1 NCHW convolution producing only 3 output channels.

    With fewer than 4 output channels the Adreno texture layout cannot be
    used (presumably the strategy falls back to a non-texture schedule —
    this test verifies that fallback builds and matches the reference).
    """
    data_shape = (1, 256, 200, 200)
    weight_shape = (3, 256, 1, 1)
    data = relay.var("data", shape=data_shape, dtype=dtype)
    weight = relay.var("weight", shape=weight_shape, dtype=dtype)

    conv = relay.nn.conv2d(
        data,
        weight,
        data_layout="NCHW",
        kernel_layout="OIHW",
        padding=[0, 0, 0, 0],
        out_dtype=dtype,
        channels=3,
        kernel_size=(1, 1),
    )
    mod = relay.Function([data, weight], conv)

    # Deterministic Xavier-initialized weights.
    np.random.seed(0)
    weight_np = np.zeros(weight_shape).astype(dtype)
    relay.testing.init.Xavier()("weight", weight_np)
    params = {"weight": tvm.nd.array(weight_np)}

    build_run_compare(remote, mod, params, {"data": data_shape}, {"data": dtype}, target, [])


if __name__ == "__main__":
    # Delegate to TVM's pytest-based entry point so the file can be run directly.
    tvm.testing.main()
30 changes: 30 additions & 0 deletions tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,5 +737,35 @@ def test_conv2d_winograd_non_rect(remote, target, dtype):
assert len(matches) > 0


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_conv2d_to_3_channels(remote, target, dtype):
    """1x1 NHWC convolution with only 3 output channels.

    Channel counts below 4 cannot use the Adreno texture layout; verify
    the compiled fallback path still matches the reference output.
    """
    dshape = (1, 200, 200, 256)
    wshape = (1, 1, 256, 3)
    inp = relay.var("data", shape=dshape, dtype=dtype)
    wgt = relay.var("weight", shape=wshape, dtype=dtype)

    out = relay.nn.conv2d(
        inp,
        wgt,
        data_layout="NHWC",
        kernel_layout="HWIO",
        padding=[0, 0, 0, 0],
        out_dtype=dtype,
        channels=3,
        kernel_size=(1, 1),
    )
    mod = relay.Function([inp, wgt], out)

    # Seed before Xavier init so the weights are reproducible.
    np.random.seed(0)
    weights = np.zeros(wshape).astype(dtype)
    relay.testing.init.Xavier()("weight", weights)

    build_run_compare(
        remote,
        mod,
        {"weight": tvm.nd.array(weights)},
        {"data": dshape},
        {"data": dtype},
        target,
        [],
    )


if __name__ == "__main__":
    # Delegate to TVM's pytest-based entry point so the file can be run directly.
    tvm.testing.main()
Original file line number Diff line number Diff line change
Expand Up @@ -192,5 +192,36 @@ def test_depthwise_conv2d_repack_bias_nchw(remote, target, dtype):
build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_conv2d_to_3_channels(remote, target, dtype):
    """Depthwise (groups == channels == 3) 1x1 NCHW convolution.

    Three channels is below the 4-channel texture block, so this exercises
    the non-texture depthwise fallback on Adreno.
    """
    in_shape = (1, 3, 200, 200)
    kern_shape = (3, 1, 1, 1)
    act = relay.var("data", shape=in_shape, dtype=dtype)
    kern = relay.var("weight", shape=kern_shape, dtype=dtype)

    result = relay.nn.conv2d(
        act,
        kern,
        data_layout="NCHW",
        kernel_layout="OIHW",
        padding=[0, 0, 0, 0],
        out_dtype=dtype,
        channels=3,
        groups=3,
        kernel_size=(1, 1),
    )
    mod = relay.Function([act, kern], result)

    # Reproducible Xavier-initialized kernel.
    np.random.seed(0)
    kern_np = np.zeros(kern_shape).astype(dtype)
    relay.testing.init.Xavier()("weight", kern_np)

    params = {"weight": tvm.nd.array(kern_np)}
    build_run_compare(remote, mod, params, {"data": in_shape}, {"data": dtype}, target, [])


if __name__ == "__main__":
    # Delegate to TVM's pytest-based entry point so the file can be run directly.
    tvm.testing.main()
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,36 @@ def test_depthwise_conv2d_1_513_513_3x3_3_3_1(remote, target, dtype):
build_run_compare(remote, mod, params1, {"data": input_shape}, {"data": dtype}, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_conv2d_to_3_channels(remote, target, dtype):
    """Depthwise (groups == channels == 3) 1x1 NHWC convolution.

    Verifies the sub-4-channel fallback path on Adreno for the
    NHWC/HWOI depthwise layout against the reference result.
    """
    activation_shape = (1, 200, 200, 3)
    kernel_shape = (1, 1, 3, 1)
    activation = relay.var("data", shape=activation_shape, dtype=dtype)
    kernel = relay.var("weight", shape=kernel_shape, dtype=dtype)

    depthwise = relay.nn.conv2d(
        activation,
        kernel,
        data_layout="NHWC",
        kernel_layout="HWOI",
        padding=[0, 0, 0, 0],
        out_dtype=dtype,
        channels=3,
        groups=3,
        kernel_size=(1, 1),
    )
    mod = relay.Function([activation, kernel], depthwise)

    # Seed, then Xavier-initialize the kernel for reproducibility.
    np.random.seed(0)
    kernel_np = np.zeros(kernel_shape).astype(dtype)
    relay.testing.init.Xavier()("weight", kernel_np)

    build_run_compare(
        remote,
        mod,
        {"weight": tvm.nd.array(kernel_np)},
        {"data": activation_shape},
        {"data": dtype},
        target,
        [],
    )


if __name__ == "__main__":
    # Delegate to TVM's pytest-based entry point so the file can be run directly.
    tvm.testing.main()
56 changes: 38 additions & 18 deletions tests/python/unittest/test_te_schedule_bound_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,22 +471,42 @@ def _check(B, A=A):
_check(te.compute((10,), lambda i: A[i]))


def test_bound_block():
    """Bound inference over an NCHW -> NCHW{block} pack/unpack pair.

    The innermost (block) axis of the packed intermediate should be
    inferred as the full block size (4), except when the channel count is
    itself smaller than the block size — then the extent collapses to C.
    """

    def _check(shape, expected_extent, block_size=4):
        n, c, h, w = shape
        remainder = c % block_size
        # Number of channel chunks, rounded up for a partial tail chunk.
        num_chunks = c // block_size + (1 if remainder != 0 else 0)

        src = te.placeholder((n, c, h, w), name="A")
        zero = tvm.tir.const(0, src.dtype)

        def _pack_nchw(*idx):
            # Positions past the tail of the last chunk are zero padding.
            in_padding = tvm.tir.all(idx[1] == num_chunks - 1, idx[4] >= remainder)
            return tvm.tir.if_then_else(
                in_padding,
                zero,
                src[idx[0], idx[1] * block_size + idx[4], idx[2], idx[3]],
            )

        repack = te.compute((n, num_chunks, h, w, block_size), _pack_nchw, name="repack")
        unpacked = te.compute(
            (n, c, h, w),
            lambda nn, cc, hh, ww: repack[nn, cc // block_size, hh, ww, cc % block_size],
            name="back_repack",
        )
        sched = te.create_schedule([unpacked.op])
        bounds = tvm.te.schedule.InferBound(sched)
        assert bounds[repack.op.axis[4]].extent.value == expected_extent

    _check((1, 4, 6, 6), 4)  # exact multiple of the block size
    _check((1, 7, 6, 6), 4)  # partial tail chunk still keeps extent 4
    _check((1, 3, 6, 6), 3)  # fewer channels than the block size -> extent 3


if __name__ == "__main__":
test_bound_nest_thread()
test_bound1()
test_bound_nest_group()
test_bound_group_schedule()
test_bound_scan()
test_bound3()
test_bound_rfactor()
test_bound_blur()
test_bound_conv1d()
test_bound2()
test_gemm_bound()
test_bound_warp()
test_bound_tensor_compute_op()
test_bound_simplification_failure()
test_bound_fusesplit1()
test_bound_fusesplit2()
test_bound_split_divisible()
test_bound_tile_divisible()
tvm.testing.main()