Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions docs/how_to/tutorials/export_and_load_executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: # type: ignore[override]
else: # pragma: no cover
TorchMLP = None # type: ignore[misc, assignment]

if not RUN_EXAMPLE:
print("Skip model conversion because PyTorch is unavailable or we are in CI.")

if RUN_EXAMPLE:
torch_model = TorchMLP().eval()
example_args = (torch.randn(1, 1, 28, 28, dtype=torch.float32),)
Expand Down Expand Up @@ -126,7 +123,7 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: # type: ignore[override]
built_mod = pipeline(mod)

# Build without params - we'll pass them at runtime
executable = relax.build(built_mod, target=TARGET)
executable = tvm.compile(built_mod, target=TARGET)

library_path = ARTIFACT_DIR / "mlp_cpu.so"
executable.export_library(str(library_path), workspace_dir=str(ARTIFACT_DIR))
Expand Down Expand Up @@ -176,7 +173,10 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: # type: ignore[override]
# TVM returns Array objects for tuple outputs, access via indexing.
# For models imported from PyTorch, outputs are typically tuples (even for single outputs).
# For ONNX models, outputs may be a single Tensor directly.
result_tensor = tvm_output[0] if isinstance(tvm_output, (tuple, list)) else tvm_output
if isinstance(tvm_output, tvm.ir.Array) and len(tvm_output) > 0:
result_tensor = tvm_output[0]
else:
result_tensor = tvm_output

print("VM output shape:", result_tensor.shape)
print("VM output type:", type(tvm_output), "->", type(result_tensor))
Expand Down Expand Up @@ -209,7 +209,7 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: # type: ignore[override]
#
# mod = from_exported_program(exported_program, keep_params_as_input=False)
# # Parameters are now embedded as constants in the module
# executable = relax.build(built_mod, target=TARGET)
# executable = tvm.compile(built_mod, target=TARGET)
# # Runtime: vm["main"](input) # No need to pass params!
#
# This creates a single-file deployment (only the ``.so`` is needed), but you
Expand Down Expand Up @@ -262,7 +262,10 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: # type: ignore[override]
#
# # Step 6: Extract result (output may be tuple or single Tensor)
# # PyTorch models typically return tuples, ONNX models may return a single Tensor
# result = output[0] if isinstance(output, (tuple, list)) else output
# if isinstance(tvm_output, tvm.ir.Array) and len(tvm_output) > 0:
# result_tensor = tvm_output[0]
# else:
# result_tensor = tvm_output
#
# print("Prediction shape:", result.shape)
# print("Predicted class:", np.argmax(result.numpy()))
Expand Down Expand Up @@ -332,7 +335,7 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: # type: ignore[override]
#
# # Step 1: Cross-compile for ARM target (on local machine)
# TARGET = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu")
# executable = relax.build(built_mod, target=TARGET)
# executable = tvm.compile(built_mod, target=TARGET)
# executable.export_library("mlp_arm.so")
#
# # Step 2: Connect to remote device RPC server
Expand Down
Loading