Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix test cases
Add some version check to avoid calling the operator/attribute which has not been introduced yet!
  • Loading branch information
jikechao authored May 22, 2023
commit 8d9173a006c415a132e672bacabcb7bc471acdc0
16 changes: 11 additions & 5 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,11 +1012,13 @@ def __init__(self, split_size_or_sections, dim):
def forward(self, *args):
return torch.tensor_split(args[0], self.split_size_or_sections, self.dim)

input_data = torch.rand(input_shape).float()
verify_model(Tensor_Split(2, 0).float().eval(), input_data=input_data)
verify_model(Tensor_Split(torch.tensor(3), 1).float().eval(), input_data=input_data)
verify_model(Tensor_Split([2, 3, 5], 1).float().eval(), input_data=input_data)
verify_model(Tensor_Split((2, 3, 5), 1).float().eval(), input_data=input_data)
# tensor_split was introduced when torch > 1.7.1
if package_version.parse(torch.__version__) > package_version.parse("1.7.1"):
input_data = torch.rand(input_shape).float()
verify_model(Tensor_Split(2, 0).float().eval(), input_data=input_data)
verify_model(Tensor_Split(torch.tensor(3), 1).float().eval(), input_data=input_data)
verify_model(Tensor_Split([2, 3, 5], 1).float().eval(), input_data=input_data)
verify_model(Tensor_Split((2, 3, 5), 1).float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -5025,6 +5027,10 @@ def forward(self, x, y):
grid_3D = torch.rand([4, 8, 8, 8, 3]).float()

for _method in methods:
# bicubic was introduced when pytorch > 1.7.1
torch_version = package_version.parse(torch.__version__)
if _method=='bicubic' and torch_version <= package_version.parse("1.7.1"):
continue
for _padding in padding_modes:
for _align in align_corners:
# ATTENTION:
Expand Down