Skip to content
Merged
Prev Previous commit
Next Next commit
linting
  • Loading branch information
hugolatendresse committed Apr 14, 2025
commit 35aee297ba2ca01dbdf2695267cf1869b399a95b
5 changes: 3 additions & 2 deletions tests/python/relax/test_from_exported_to_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,14 @@ def __init__(self):

def forward(self, x):
return torch.full((2, 3), 3.141592)

torch_module = FullModel().eval()

raw_data = np.random.rand(3,3).astype("float32")
raw_data = np.random.rand(3, 3).astype("float32")

assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_tensor_clamp(target, dev):
class ClampBothTensor(torch.nn.Module):
Expand Down