Skip to content

Commit aa61558

Browse files
committed
Merged in SYSOL-584-pytorch-conv-transpose-pad-fix (pull request apache#34)
SYSOL-584 Pytorch Conv Transpose Padding Fix Approved-by: Alicja Kwasniewska Approved-by: Mikael Sevenier
2 parents 72f957b + 213886b commit aa61558

2 files changed

Lines changed: 90 additions & 20 deletions

File tree

python/tvm/relay/frontend/pytorch.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -970,32 +970,52 @@ def convolution(self, inputs, input_types):
970970
kernel_size = weight_shape[2:]
971971
use_bias = isinstance(bias, _expr.Expr)
972972

973-
if len(kernel_size) == 1:
974-
strides = (1,) + strides
975-
padding = (0,) + padding
976-
dilation = (1,) + dilation
973+
# We are trying to invoke various relay operations through a single conv_op variable.
974+
# However the function signatures for some operations have additional attributes so we
975+
# pass these in along with the standard ones.
976+
additional_arguments = dict()
977977

978978
if use_transpose:
979979
if len(kernel_size) == 3:
980980
conv_op = _op.nn.conv3d_transpose
981-
else:
981+
elif len(kernel_size) == 2:
982982
conv_op = _op.nn.conv2d_transpose
983+
else:
984+
conv_op = _op.nn.conv1d_transpose
985+
output_padding = tuple(inputs[7])
986+
additional_arguments['output_padding'] = output_padding
987+
983988
else:
984989
if len(kernel_size) == 3:
985990
conv_op = _op.nn.conv3d
986-
else:
991+
elif len(kernel_size) == 2:
987992
conv_op = _op.nn.conv2d
993+
else:
994+
conv_op = _op.nn.conv1d
988995

989996
if len(kernel_size) == 3:
990997
data_layout = "NCDHW"
991998
kernel_layout = "OIDHW"
992-
else:
999+
elif len(kernel_size) == 2:
9931000
data_layout = "NCHW"
9941001
kernel_layout = "OIHW"
995-
996-
if len(kernel_size) == 1:
1002+
else:
1003+
data_layout = "NCW"
1004+
kernel_layout = "OIW"
1005+
1006+
# Conv1d does not currently support grouped convolution so we convert it to conv2d
1007+
is_grouped_conv1d = False
1008+
if groups > 1 and len(kernel_size) == 1 and not use_transpose:
1009+
is_grouped_conv1d = True
1010+
conv_op = _op.nn.conv2d
1011+
kernel_size = [1] + kernel_size
1012+
strides = (1,) + strides
1013+
padding = (0,) + padding
1014+
dilation = (1,) + dilation
9971015
data = _op.expand_dims(data, axis=2)
9981016
weight = _op.expand_dims(weight, axis=2)
1017+
data_layout = "NCHW"
1018+
kernel_layout = "OIHW"
9991019

10001020
conv_out = conv_op(
10011021
data,
@@ -1005,17 +1025,20 @@ def convolution(self, inputs, input_types):
10051025
dilation=dilation,
10061026
groups=groups,
10071027
channels=channels,
1008-
kernel_size=[1] + kernel_size if len(kernel_size) == 1 else kernel_size,
1028+
kernel_size=kernel_size,
10091029
data_layout=data_layout,
10101030
kernel_layout=kernel_layout,
10111031
out_layout="",
10121032
out_dtype="",
1033+
**additional_arguments,
10131034
)
10141035
if use_bias:
10151036
res = _op.nn.bias_add(conv_out, bias)
10161037
else:
10171038
res = conv_out
1018-
if len(kernel_size) == 1:
1039+
if is_grouped_conv1d:
1040+
# Because we conducted grouped conv1d convolution through conv2d we must
1041+
# squeeze the output to get the correct result.
10191042
res = _op.squeeze(res, axis=[2])
10201043
return res
10211044

tests/python/frontend/pytorch/test_forward.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from tvm.contrib.nvcc import have_fp16
3232
import tvm.testing
3333
from packaging import version as package_version
34+
import pytest
3435

3536
sys.setrecursionlimit(10000)
3637

@@ -961,17 +962,63 @@ def forward(self, *args):
961962

962963

963964
@tvm.testing.uses_gpu
964-
def test_forward_conv_transpose():
965-
torch.set_grad_enabled(False)
966-
conv2d_input_shape = [1, 3, 10, 10]
965+
@pytest.mark.parametrize("in_channels", [3], ids=lambda x: 'in_channels=' + str(x))
966+
@pytest.mark.parametrize("out_channels", [5], ids=lambda x: 'out_channels=' + str(x))
967+
@pytest.mark.parametrize("kernel_size", [3], ids=lambda x: 'kernel_size=' + str(x))
968+
@pytest.mark.parametrize("output_padding", [0, 1, 2], ids=lambda x: 'output_padding=' + str(x))
969+
@pytest.mark.parametrize("groups", [1], ids=lambda x: 'groups=' + str(x))
970+
@pytest.mark.parametrize("bias", [True, False], ids=lambda x: 'bias=' + str(x))
971+
def test_forward_conv_transpose(in_channels,
972+
out_channels,
973+
kernel_size,
974+
output_padding,
975+
bias,
976+
groups):
977+
# Note we do not test with groups > 1 because that is not supported
978+
# in tvm for conv transpose operations
979+
980+
# Output padding must be smaller than either stride or dilation so we
981+
# opt to make the stride 1 + output padding
982+
stride = output_padding + 1
983+
984+
#Conv 3D Transpose Tests
985+
conv3d_input_shape = [1, in_channels, 16, 16, 16]
986+
conv3d_input_data = torch.rand(conv3d_input_shape).float()
987+
conv3d_transpose = torch.nn.ConvTranspose3d(in_channels=in_channels,
988+
out_channels=out_channels,
989+
kernel_size=kernel_size,
990+
stride=stride,
991+
output_padding=output_padding,
992+
groups=groups,
993+
bias=bias,
994+
).eval()
995+
verify_model(conv3d_transpose, conv3d_input_data)
996+
997+
# Conv 2D Transpose Tests
998+
conv2d_input_shape = [1, in_channels, 128, 256]
967999
conv2d_input_data = torch.rand(conv2d_input_shape).float()
968-
verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=conv2d_input_data)
969-
verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=conv2d_input_data)
970-
971-
conv1d_input_shape = [1, 3, 10]
1000+
conv2d_transpose = torch.nn.ConvTranspose2d(in_channels=in_channels,
1001+
out_channels=out_channels,
1002+
kernel_size=kernel_size,
1003+
stride=stride,
1004+
output_padding=output_padding,
1005+
groups=groups,
1006+
bias=bias,
1007+
).eval()
1008+
verify_model(conv2d_transpose, conv2d_input_data)
1009+
1010+
# # Conv 1D Transpose Tests
1011+
conv1d_input_shape = [1, in_channels, 10]
9721012
conv1d_input_data = torch.rand(conv1d_input_shape).float()
973-
verify_model(torch.nn.ConvTranspose1d(3, 6, 7, bias=True), input_data=conv1d_input_data)
974-
verify_model(torch.nn.ConvTranspose1d(3, 12, 3, bias=False), input_data=conv1d_input_data)
1013+
conv1d_transpose = torch.nn.ConvTranspose1d(in_channels=in_channels,
1014+
out_channels=out_channels,
1015+
kernel_size=kernel_size,
1016+
stride=stride,
1017+
output_padding=output_padding,
1018+
groups=groups,
1019+
bias=bias,
1020+
).eval()
1021+
verify_model(conv1d_transpose, conv1d_input_data)
9751022

9761023

9771024
def test_forward_deform_conv():

0 commit comments

Comments (0)