mobilenet_v3 cannot be imported with the Relax PyTorch frontend (from_fx) because the Hardswish and Hardsigmoid activations are not supported. I'll try to fix it.
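Both activations are piecewise-linear, so they can be expressed with elementwise add, clip, multiply and divide. The pure-Python sketch below checks that decomposition against the piecewise definitions given in the PyTorch documentation for torch.nn.Hardsigmoid and torch.nn.Hardswish. It works on scalars only and uses no TVM API; whether a fix in the frontend would lower the modules exactly this way is an assumption.

```python
def clamp(v: float, lo: float, hi: float) -> float:
    """Clip v into the closed interval [lo, hi]."""
    return min(max(v, lo), hi)


def hardsigmoid_ref(x: float) -> float:
    # Piecewise definition as documented for torch.nn.Hardsigmoid.
    if x <= -3.0:
        return 0.0
    if x >= 3.0:
        return 1.0
    return x / 6.0 + 0.5


def hardswish_ref(x: float) -> float:
    # Piecewise definition as documented for torch.nn.Hardswish.
    if x <= -3.0:
        return 0.0
    if x >= 3.0:
        return x
    return x * (x + 3.0) / 6.0


def hardsigmoid_decomposed(x: float) -> float:
    # Only an add, a clip and a divide: relu6(x + 3) / 6.
    return clamp(x + 3.0, 0.0, 6.0) / 6.0


def hardswish_decomposed(x: float) -> float:
    # hardswish(x) = x * hardsigmoid(x)
    return x * hardsigmoid_decomposed(x)


if __name__ == "__main__":
    for x in [-5.0, -3.0, -1.5, 0.0, 0.5, 2.9, 3.0, 7.0]:
        assert abs(hardsigmoid_ref(x) - hardsigmoid_decomposed(x)) < 1e-12
        assert abs(hardswish_ref(x) - hardswish_decomposed(x)) < 1e-12
    print("decompositions match")
```

The same identities hold elementwise on tensors, which is why a converter could in principle reuse existing clip and arithmetic ops instead of a dedicated kernel.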
TODOs
Expected behavior
mobilenet_v3_small and mobilenet_v3_large can be imported with from_fx.
Actual behavior
Running the repro produces the following error:
$ python compile_mobilenet_v3.py
Traceback (most recent call last):
  File "/home/ubuntu/data/sandbox/tvm_/relax_/mobilenet_v3/compile_mobilenet_v3.py", line 34, in <module>
    main()
  File "/home/ubuntu/data/sandbox/tvm_/relax_/mobilenet_v3/compile_mobilenet_v3.py", line 21, in main
    mod = from_fx(graph_model, [(inp.shape, "float32")])
  File "/home/ubuntu/data/sandbox/.dep/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1698, in from_fx
    return TorchFXImporter().from_fx(
  File "/home/ubuntu/data/sandbox/.dep/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1570, in from_fx
    type(module) in self.convert_map
AssertionError: Unsupported module type <class 'torch.nn.modules.activation.Hardswish'>
[20:07:07] /home/ubuntu/data/sandbox/.dep/tvm/src/relax/ir/block_builder.cc:66: Warning: BlockBuilder destroyed with remaining blocks!
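The assertion in the traceback comes from checking the module's class against the importer's convert_map. The sketch below is a simplified, hypothetical model of that kind of type-keyed dispatch (stand-in classes, scalar inputs, no torch or TVM imports). It is not TVM's actual code; it only illustrates why an unregistered class such as Hardswish trips the check and how registering a converter for it would change that.

```python
from typing import Callable, Dict


class ReLU:
    """Hypothetical stand-in for torch.nn.ReLU (no torch dependency)."""


class Hardswish:
    """Hypothetical stand-in for torch.nn.Hardswish (no torch dependency)."""


# Hypothetical converter table keyed by module class; it only mirrors the
# `type(module) in self.convert_map` check seen in the traceback.
convert_map: Dict[type, Callable[[float], float]] = {
    ReLU: lambda x: max(x, 0.0),
}


def convert(module: object, x: float) -> float:
    # Same shape of check as the importer: unknown module classes fail fast.
    assert type(module) in convert_map, f"Unsupported module type {type(module)}"
    return convert_map[type(module)](x)


def hardswish(x: float) -> float:
    # x * relu6(x + 3) / 6, written with min/max instead of a clip op.
    return x * min(max(x + 3.0, 0.0), 6.0) / 6.0


if __name__ == "__main__":
    try:
        convert(Hardswish(), 1.0)
    except AssertionError as err:
        print("before registration:", err)
    convert_map[Hardswish] = hardswish  # register a converter for the class
    print("after registration:", convert(Hardswish(), 1.0))
```

Keying the table on the exact class means subclasses and new activation modules are rejected until someone adds an entry, which matches the behavior reported above.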
Environment
OS: Ubuntu 22.04 LTS on WSL2
TVM: 0e622e1
PyTorch: 2.3.0
Torchvision: 0.18.0
Steps to reproduce
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx
import torch
import torchvision


def main():
    model_name = "mobilenet_v3_small"  # mobilenet_v3_small or mobilenet_v3_large
    inp = torch.rand(8, 3, 224, 224)
    weights = torchvision.models.get_model_weights(model_name).DEFAULT
    model_pth = torchvision.models.get_model(model_name, weights=weights).eval()

    # PyTorch
    output_pth = model_pth(inp)

    # TVM
    graph_model = torch.fx.symbolic_trace(model_pth)
    with torch.no_grad():
        mod = from_fx(graph_model, [(inp.shape, "float32")])

    target = tvm.target.Target("llvm", host="llvm")
    mod = relax.transform.DecomposeOpsForInference()(mod)
    mod = relax.transform.LegalizeOps()(mod)
    ex = relax.build(mod, target)
    vm = relax.VirtualMachine(ex, tvm.cpu())
    output_tvm = torch.tensor(vm["main"](tvm.nd.array(inp.detach().numpy())).numpy())

    torch.testing.assert_close(output_pth, output_tvm, rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
    main()
Triage
cc @junrushao