|
31 | 31 | from tvm.contrib.nvcc import have_fp16 |
32 | 32 | import tvm.testing |
33 | 33 | from packaging import version as package_version |
| 34 | +import pytest |
34 | 35 |
|
35 | 36 | sys.setrecursionlimit(10000) |
36 | 37 |
|
@@ -961,17 +962,63 @@ def forward(self, *args): |
961 | 962 |
|
962 | 963 |
|
@tvm.testing.uses_gpu
@pytest.mark.parametrize("in_channels", [3], ids=lambda x: 'in_channels=' + str(x))
@pytest.mark.parametrize("out_channels", [5], ids=lambda x: 'out_channels=' + str(x))
@pytest.mark.parametrize("kernel_size", [3], ids=lambda x: 'kernel_size=' + str(x))
@pytest.mark.parametrize("output_padding", [0, 1, 2], ids=lambda x: 'output_padding=' + str(x))
@pytest.mark.parametrize("groups", [1], ids=lambda x: 'groups=' + str(x))
@pytest.mark.parametrize("bias", [True, False], ids=lambda x: 'bias=' + str(x))
def test_forward_conv_transpose(in_channels,
                                out_channels,
                                kernel_size,
                                output_padding,
                                bias,
                                groups):
    """Check ConvTranspose1d/2d/3d lowering against PyTorch via verify_model.

    Note we do not test with groups > 1 because that is not supported
    in tvm for conv transpose operations.
    """
    # Output padding must be smaller than either stride or dilation, so we
    # opt to make the stride 1 + output padding to keep every case legal.
    stride = output_padding + 1

    # The 3D, 2D, and 1D cases differ only in the constructor and the input
    # shape, so drive all three through one loop instead of three copies.
    cases = [
        (torch.nn.ConvTranspose3d, [1, in_channels, 16, 16, 16]),
        (torch.nn.ConvTranspose2d, [1, in_channels, 128, 256]),
        (torch.nn.ConvTranspose1d, [1, in_channels, 10]),
    ]
    for conv_op, input_shape in cases:
        input_data = torch.rand(input_shape).float()
        conv_transpose = conv_op(in_channels=in_channels,
                                 out_channels=out_channels,
                                 kernel_size=kernel_size,
                                 stride=stride,
                                 output_padding=output_padding,
                                 groups=groups,
                                 bias=bias,
                                 ).eval()
        verify_model(conv_transpose, input_data)
975 | 1022 |
|
976 | 1023 |
|
977 | 1024 | def test_forward_deform_conv(): |
|
0 commit comments