From 7a08cd241b4d0bef9cf81f79496418be9d174170 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 27 Mar 2026 22:48:19 -0400 Subject: [PATCH 1/6] finish1 --- .../mix_python_and_tvm_with_pymodule.py | 317 ++++++++++++++++++ python/tvm/relax/base_py_module.py | 9 +- 2 files changed, 318 insertions(+), 8 deletions(-) create mode 100644 docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py diff --git a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py new file mode 100644 index 000000000000..fc01af07795c --- /dev/null +++ b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py @@ -0,0 +1,317 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ruff: noqa: E402 + +""" +.. _mix_python_and_tvm: + +Mix Python/PyTorch with TVM Using BasePyModule +=============================================== +This tutorial shows how to mix Python functions, TIR kernels, and Relax graph-level functions +in a single ``IRModule`` using the ``BasePyModule`` system. The key benefits are: + +- **Debug without compiling**: Run IRModules directly in Python, calling TIR and Relax functions + through JIT compilation while keeping Python functions as-is. 
+- **PyTorch interop**: Use PyTorch operators as fallbacks for ops TVM does not yet support, + with zero-copy DLPack tensor conversion. +- **Relax-to-Python conversion**: Automatically translate compiled Relax functions into equivalent + PyTorch code for numerical verification at any compilation stage. + +.. contents:: Table of Contents + :local: + :depth: 1 +""" + +###################################################################### +# Preparation +# ----------- +# We import the necessary modules. ``BasePyModule`` is the base class that enables Python function +# integration with TVM's IRModule. The ``I``, ``T``, ``R`` namespaces provide TVMScript decorators +# for IR modules, TIR functions, and Relax functions respectively. + +import os + +try: + import torch +except ImportError: + torch = None + +import tvm +from tvm.relax.base_py_module import BasePyModule +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tirx as T + +IS_IN_CI = os.getenv("CI", "").lower() == "true" +HAS_TORCH = torch is not None +RUN_EXAMPLE = HAS_TORCH and not IS_IN_CI + + +###################################################################### +# Part 1: BasePyModule Basics +# ---------------------------- +# A ``BasePyModule`` wraps an ``IRModule`` and provides: +# +# - Automatic JIT compilation of TIR and Relax functions +# - DLPack-based zero-copy conversion between PyTorch tensors and TVM NDArrays +# - A unified interface where Python, TIR, and Relax functions coexist +# +# Let us start with a simple example: a module that contains one TIR function (element-wise add) +# and one Python function that orchestrates the computation using PyTorch tensors. 
+ +if RUN_EXAMPLE: + + @I.ir_module + class MyModule(BasePyModule): + @I.pyfunc + def forward(self, x, y): + """Python function: receives PyTorch tensors, calls TIR, returns PyTorch tensors.""" + # Convert PyTorch tensors to TVM NDArrays (zero-copy via DLPack) + x_tvm = self._convert_pytorch_to_tvm(x) + y_tvm = self._convert_pytorch_to_tvm(y) + + # Call the TIR function below + result = self.call_tir( + self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((4,), "float32") + ) + + # Convert back to PyTorch (zero-copy via DLPack) + return self._convert_tvm_to_pytorch(result) + + @T.prim_func + def add_tir( + A: T.Buffer((4,), "float32"), + B: T.Buffer((4,), "float32"), + C: T.Buffer((4,), "float32"), + ): + for i in range(4): + C[i] = A[i] + B[i] + + # Instantiate the module on CPU. TIR functions are JIT-compiled at this point. + mod = MyModule(device=tvm.cpu(0)) + + # Call the Python function with PyTorch tensors + x = torch.tensor([1.0, 2.0, 3.0, 4.0]) + y = torch.tensor([10.0, 20.0, 30.0, 40.0]) + result = mod.forward(x, y) + + print("Input x:", x) + print("Input y:", y) + print("Result (x + y via TIR):", result) + assert torch.allclose(result, x + y) + + # BasePyModule also supports pretty-printing via show(), including Python functions + print("\n=== Module TVMScript ===") + mod.show() + + +###################################################################### +# How it Works +# ~~~~~~~~~~~~ +# When the class is decorated with ``@I.ir_module`` and inherits from ``BasePyModule``: +# +# 1. Methods decorated with ``@T.prim_func`` and ``@R.function`` are parsed into TIR/Relax IR +# and stored in the underlying ``IRModule``. +# 2. Methods decorated with ``@I.pyfunc`` are registered as Python functions. +# 3. On instantiation (``MyModule(device=...)``), TIR functions are compiled via ``tvm.compile`` +# and Relax functions are loaded into a ``VirtualMachine``. Python functions remain as-is. +# 4. 
``call_tir`` handles DLPack conversion, output allocation, and calling the compiled kernel. +# + + +###################################################################### +# Part 2: Mixing TIR, Relax, and Python +# ---------------------------------------- +# A single module can contain all three kinds of functions. This is useful when some operations +# are best expressed as low-level TIR kernels, others as high-level Relax graphs, and some +# require Python-level logic (e.g., dynamic control flow, calling external libraries). + +if RUN_EXAMPLE: + + @I.ir_module + class HybridModule(BasePyModule): + @I.pyfunc + def preprocess(self, x): + """Use PyTorch for preprocessing — e.g., normalization.""" + mean = x.mean() + std = x.std() + return (x - mean) / (std + 1e-5) + + @I.pyfunc + def run_pipeline(self, x): + """Orchestrate: Python preprocessing -> TIR computation -> result.""" + # Step 1: Python-based preprocessing + normalized = self.preprocess(x) + + # Step 2: Convert and run TIR kernel + tvm_input = self._convert_pytorch_to_tvm(normalized) + tvm_result = self.call_tir( + self.scale_tir, [tvm_input], out_sinfo=R.Tensor((4,), "float32") + ) + return self._convert_tvm_to_pytorch(tvm_result) + + @T.prim_func + def scale_tir(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): + for i in range(4): + B[i] = A[i] * T.float32(2.0) + + mod = HybridModule(device=tvm.cpu(0)) + + x = torch.tensor([1.0, 3.0, 5.0, 7.0]) + result = mod.run_pipeline(x) + print("Pipeline result:", result) + + +###################################################################### +# Part 3: Adding Python Functions Dynamically +# --------------------------------------------- +# You can also register Python functions after module creation using ``add_python_function``. +# This is useful for attaching PyTorch-based fallback operators or custom post-processing. 
+ +if RUN_EXAMPLE: + py_mod = BasePyModule(tvm.IRModule({}), device=tvm.cpu(0)) + + # Dynamically add a Python function + def my_activation(x): + """Custom activation using PyTorch.""" + return torch.relu(x) + torch.tanh(x) + + py_mod.add_python_function("my_activation", my_activation) + + # Now we can call it as a method + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) + result = py_mod.my_activation(x) + print("Custom activation:", result) + expected = torch.relu(x) + torch.tanh(x) + assert torch.allclose(result, expected) + + +###################################################################### +# Part 4: Relax-to-Python Function Converter +# -------------------------------------------- +# One powerful feature is the ability to automatically convert Relax functions into equivalent +# PyTorch code. This is useful for: +# +# - Numerically verifying Relax IR against PyTorch after applying optimization passes +# - Debugging: inspect what a Relax function actually computes by running it in Python +# - Prototyping: test Relax graph transformations without a full compilation cycle +# +# The ``RelaxToPyFuncConverter`` maps 300+ Relax operators to their PyTorch equivalents. 
+ +if RUN_EXAMPLE: + from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter + + @I.ir_module + class RelaxModel: + @T.prim_func + def custom_add(var_x: T.handle, var_y: T.handle, var_out: T.handle): + x = T.match_buffer(var_x, (5,), "float32") + y = T.match_buffer(var_y, (5,), "float32") + out = T.match_buffer(var_out, (5,), "float32") + for i in range(5): + out[i] = x[i] + y[i] + + @R.function + def main( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + # Mix of Relax ops and TIR calls + added = R.add(x, y) + activated = R.nn.relu(added) + cls = RelaxModel + result = R.call_tir(cls.custom_add, (activated, y), out_sinfo=R.Tensor((5,), "float32")) + return result + + # Convert the Relax function "main" to an equivalent Python/PyTorch function + converter = RelaxToPyFuncConverter(RelaxModel) + converted_mod = converter.convert(["main"]) + + # The converted function lives in ir_mod.pyfuncs and accepts PyTorch tensors directly + x = torch.tensor([1.0, -2.0, 3.0, -4.0, 5.0]) + y = torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5]) + + py_result = converted_mod.pyfuncs["main"](x, y) + + # Manually compute the expected result for verification + step1 = torch.add(x, y) # [1.5, -1.5, 3.5, -3.5, 5.5] + step2 = torch.relu(step1) # [1.5, 0.0, 3.5, 0.0, 5.5] + expected = step2 + y # [2.0, 0.5, 4.0, 0.5, 6.0] + + print("Relax function converted to Python:") + print(" Input x:", x) + print(" Input y:", y) + print(" Python result:", py_result) + print(" Expected: ", expected) + assert torch.allclose(py_result, expected) + + +###################################################################### +# Part 5: Using R.call_py_func in Relax IR +# ------------------------------------------ +# ``R.call_py_func`` lets you embed Python function calls directly inside Relax IR. This means +# the compiled Relax VM can call back into Python at runtime. 
This is the bridge for ops that +# TVM cannot compile natively — the rest of the graph is compiled and optimized, while specific +# ops fall back to Python/PyTorch. +# +# .. note:: +# ``R.call_py_func`` adds runtime overhead due to the Python-TVM boundary crossing. +# Use it for prototyping or for ops that are not performance-critical. +# +# Here is an example using ``call_py_func`` inside a Relax function: +# +# .. code-block:: python +# +# @I.ir_module +# class CallPyFuncModule(BasePyModule): +# @I.pyfunc +# def my_custom_op(self, x): +# """Python fallback for a custom op.""" +# return torch.sigmoid(x) * x # SiLU / Swish activation +# +# @R.function +# def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"): +# # Call the Python function from within Relax IR +# result = R.call_py_func( +# "my_custom_op", (x,), out_sinfo=R.Tensor((4,), "float32") +# ) +# return result +# +# mod = CallPyFuncModule(device=tvm.cpu(0)) +# x = torch.tensor([1.0, -1.0, 2.0, -2.0]) +# result = mod.main(x) +# +# The VM executes the compiled Relax bytecode, and when it hits ``call_py_func``, it looks up +# the registered Python function by name and calls it with DLPack-converted tensors. + + +###################################################################### +# Summary +# ------- +# This tutorial covered the Relax Python Module system: +# +# - **BasePyModule**: A base class that unifies Python, TIR, and Relax functions in one module, +# with JIT compilation and DLPack-based tensor conversion. +# - **@I.pyfunc**: Decorator to mark Python functions inside an ``@I.ir_module`` class. +# - **Dynamic registration**: ``add_python_function()`` to attach Python functions after creation. +# - **RelaxToPyFuncConverter**: Automatically converts Relax functions to PyTorch for debugging +# and numerical verification. +# - **R.call_py_func**: Embeds Python function calls in Relax IR, enabling fallback to PyTorch +# for unsupported ops while keeping the rest of the graph compiled. 
+# +# Together, these features make TVM a hybrid execution framework where you can freely mix +# compiled TVM code and Python/PyTorch, enabling faster iteration during development +# and gradual migration of ops from Python fallbacks to optimized TVM kernels. diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index 7840e5b3b4c6..67b676163326 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -146,14 +146,7 @@ def _compile_functions(self): except Exception as error: print(f"Warning: Failed to compile one or more TIR functions: {error}") - relax_mod = tvm.IRModule( - { - gv: func - for gv, func in self.ir_mod.functions_items() - if isinstance(func, relax.Function) - } - ) - if relax_mod: + if self.relax_func_names: try: exec_mod = tvm.compile(self.ir_mod, target=self.target) self.relax_vm = relax.VirtualMachine(exec_mod, self.device) From d90354bf9ec85641e34d829936cd3c4a102aa78b Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 27 Mar 2026 23:13:02 -0400 Subject: [PATCH 2/6] finish2 --- .../mix_python_and_tvm_with_pymodule.py | 64 +++++++++++-------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py index fc01af07795c..38c7e0f460ce 100644 --- a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py +++ b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py @@ -270,32 +270,44 @@ def main( # .. note:: # ``R.call_py_func`` adds runtime overhead due to the Python-TVM boundary crossing. # Use it for prototyping or for ops that are not performance-critical. -# -# Here is an example using ``call_py_func`` inside a Relax function: -# -# .. 
code-block:: python -# -# @I.ir_module -# class CallPyFuncModule(BasePyModule): -# @I.pyfunc -# def my_custom_op(self, x): -# """Python fallback for a custom op.""" -# return torch.sigmoid(x) * x # SiLU / Swish activation -# -# @R.function -# def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"): -# # Call the Python function from within Relax IR -# result = R.call_py_func( -# "my_custom_op", (x,), out_sinfo=R.Tensor((4,), "float32") -# ) -# return result -# -# mod = CallPyFuncModule(device=tvm.cpu(0)) -# x = torch.tensor([1.0, -1.0, 2.0, -2.0]) -# result = mod.main(x) -# -# The VM executes the compiled Relax bytecode, and when it hits ``call_py_func``, it looks up -# the registered Python function by name and calls it with DLPack-converted tensors. + +if RUN_EXAMPLE: + + @I.ir_module + class CallPyFuncModule(BasePyModule): + @I.pyfunc + def torch_relu(self, x): + """Python fallback: PyTorch ReLU.""" + return torch.relu(x) + + @I.pyfunc + def torch_softmax(self, x, dim=0): + """Python fallback: PyTorch softmax.""" + return torch.softmax(x, dim=dim) + + @R.function + def main(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"): + # The VM calls back into Python for these ops at runtime + relu_result = R.call_py_func( + "torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32") + ) + result = R.call_py_func( + "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32") + ) + return result + + mod = CallPyFuncModule(device=tvm.cpu(0)) + + x = torch.randn(10, dtype=torch.float32) + + # call_py_func can be called directly from Python as well + relu_result = mod.call_py_func("torch_relu", [x]) + result = mod.call_py_func("torch_softmax", [relu_result]) + + expected = torch.softmax(torch.relu(x), dim=0) + print("R.call_py_func result:", result) + print("PyTorch expected: ", expected) + assert torch.allclose(torch.tensor(result.numpy()), expected, atol=1e-5) ###################################################################### From 
3ccc7b028c5ac8232e4b8a6d4a048d80f6403e9e Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sat, 28 Mar 2026 01:21:29 -0400 Subject: [PATCH 3/6] finsh4 --- .../mix_python_and_tvm_with_pymodule.py | 423 ++++++++++-------- 1 file changed, 241 insertions(+), 182 deletions(-) diff --git a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py index 38c7e0f460ce..255a818c3f20 100644 --- a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py +++ b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py @@ -21,15 +21,23 @@ Mix Python/PyTorch with TVM Using BasePyModule =============================================== -This tutorial shows how to mix Python functions, TIR kernels, and Relax graph-level functions -in a single ``IRModule`` using the ``BasePyModule`` system. The key benefits are: +In a typical TVM workflow, you write an ``IRModule``, compile it, and load the compiled artifact +into a ``VirtualMachine`` to run. This means **you cannot test or debug anything until the entire +module compiles successfully**. If a single op is unsupported, the whole pipeline is blocked. -- **Debug without compiling**: Run IRModules directly in Python, calling TIR and Relax functions - through JIT compilation while keeping Python functions as-is. -- **PyTorch interop**: Use PyTorch operators as fallbacks for ops TVM does not yet support, - with zero-copy DLPack tensor conversion. -- **Relax-to-Python conversion**: Automatically translate compiled Relax functions into equivalent - PyTorch code for numerical verification at any compilation stage. +``BasePyModule`` solves this by letting Python functions, TIR kernels, and Relax functions coexist +in one module. TIR and Relax functions are JIT-compiled, Python functions run as-is, and tensors +move between TVM and PyTorch via zero-copy DLPack. 
This enables: + +- **Incremental development**: get a model running with Python fallbacks first, then replace them + with TVM ops one by one. +- **Debugging at any stage**: convert Relax functions back to PyTorch to verify numerical + correctness after applying optimization passes. +- **Hybrid execution**: let the compiled VM call back into Python for ops that are hard to + express in TIR or Relax. + +This tutorial walks through a concrete example: building a small model where Python, TIR, and +Relax functions work together, then using the converter and ``call_py_func`` to debug and extend it. .. contents:: Table of Contents :local: @@ -39,18 +47,17 @@ ###################################################################### # Preparation # ----------- -# We import the necessary modules. ``BasePyModule`` is the base class that enables Python function -# integration with TVM's IRModule. The ``I``, ``T``, ``R`` namespaces provide TVMScript decorators -# for IR modules, TIR functions, and Relax functions respectively. 
import os try: import torch + import torch.nn.functional as F except ImportError: torch = None import tvm +from tvm import relax from tvm.relax.base_py_module import BasePyModule from tvm.script import ir as I from tvm.script import relax as R @@ -62,36 +69,23 @@ ###################################################################### -# Part 1: BasePyModule Basics -# ---------------------------- -# A ``BasePyModule`` wraps an ``IRModule`` and provides: +# Step 1: Your First Hybrid Module +# ---------------------------------- +# The core idea: decorate a class with ``@I.ir_module``, inherit from ``BasePyModule``, and use +# three decorators for three kinds of functions: # -# - Automatic JIT compilation of TIR and Relax functions -# - DLPack-based zero-copy conversion between PyTorch tensors and TVM NDArrays -# - A unified interface where Python, TIR, and Relax functions coexist +# - ``@T.prim_func`` — low-level TIR kernel (compiled) +# - ``@R.function`` — high-level Relax graph (compiled) +# - ``@I.pyfunc`` — plain Python (runs as-is, can call PyTorch) # -# Let us start with a simple example: a module that contains one TIR function (element-wise add) -# and one Python function that orchestrates the computation using PyTorch tensors. +# On instantiation, TIR and Relax functions are JIT-compiled. Python functions stay in Python. +# ``call_tir`` bridges them: it converts PyTorch tensors to TVM via DLPack, allocates the output, +# calls the compiled kernel, and converts back. 
if RUN_EXAMPLE: @I.ir_module - class MyModule(BasePyModule): - @I.pyfunc - def forward(self, x, y): - """Python function: receives PyTorch tensors, calls TIR, returns PyTorch tensors.""" - # Convert PyTorch tensors to TVM NDArrays (zero-copy via DLPack) - x_tvm = self._convert_pytorch_to_tvm(x) - y_tvm = self._convert_pytorch_to_tvm(y) - - # Call the TIR function below - result = self.call_tir( - self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((4,), "float32") - ) - - # Convert back to PyTorch (zero-copy via DLPack) - return self._convert_tvm_to_pytorch(result) - + class MyFirstModule(BasePyModule): @T.prim_func def add_tir( A: T.Buffer((4,), "float32"), @@ -101,116 +95,119 @@ def add_tir( for i in range(4): C[i] = A[i] + B[i] - # Instantiate the module on CPU. TIR functions are JIT-compiled at this point. - mod = MyModule(device=tvm.cpu(0)) + @I.pyfunc + def forward(self, x, y): + """Takes PyTorch tensors, calls TIR, returns PyTorch tensors.""" + x_tvm = self._convert_pytorch_to_tvm(x) + y_tvm = self._convert_pytorch_to_tvm(y) + result = self.call_tir( + self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((4,), "float32") + ) + return self._convert_tvm_to_pytorch(result) + + # TIR functions are JIT-compiled here + mod = MyFirstModule(device=tvm.cpu(0)) - # Call the Python function with PyTorch tensors x = torch.tensor([1.0, 2.0, 3.0, 4.0]) y = torch.tensor([10.0, 20.0, 30.0, 40.0]) result = mod.forward(x, y) - print("Input x:", x) - print("Input y:", y) - print("Result (x + y via TIR):", result) + print("forward(x, y) =", result) assert torch.allclose(result, x + y) - # BasePyModule also supports pretty-printing via show(), including Python functions - print("\n=== Module TVMScript ===") + # show() prints the TVMScript representation, including Python functions as ExternFunc mod.show() ###################################################################### -# How it Works -# ~~~~~~~~~~~~ -# When the class is decorated with ``@I.ir_module`` and inherits from 
``BasePyModule``: +# Step 2: A Realistic Pipeline — Python, TIR, and Relax Together +# ----------------------------------------------------------------- +# Real models are not just one op. Here we build a mini inference pipeline: # -# 1. Methods decorated with ``@T.prim_func`` and ``@R.function`` are parsed into TIR/Relax IR -# and stored in the underlying ``IRModule``. -# 2. Methods decorated with ``@I.pyfunc`` are registered as Python functions. -# 3. On instantiation (``MyModule(device=...)``), TIR functions are compiled via ``tvm.compile`` -# and Relax functions are loaded into a ``VirtualMachine``. Python functions remain as-is. -# 4. ``call_tir`` handles DLPack conversion, output allocation, and calling the compiled kernel. +# 1. **Python** preprocesses the input (normalization — easy in PyTorch, verbose in TIR) +# 2. **TIR** runs a hand-written matmul kernel +# 3. **Python** applies softmax via PyTorch (a temporary fallback) # - - -###################################################################### -# Part 2: Mixing TIR, Relax, and Python -# ---------------------------------------- -# A single module can contain all three kinds of functions. This is useful when some operations -# are best expressed as low-level TIR kernels, others as high-level Relax graphs, and some -# require Python-level logic (e.g., dynamic control flow, calling external libraries). +# The key point: you do not need every op to be a TIR kernel to get the module running. +# Write what you can in TIR, fall back to Python for the rest, and iterate. 
if RUN_EXAMPLE: @I.ir_module - class HybridModule(BasePyModule): + class InferenceModule(BasePyModule): + @T.prim_func + def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle): + n = T.int32() + A = T.match_buffer(var_A, (n, 8), "float32") + B = T.match_buffer(var_B, (8, 4), "float32") + C = T.match_buffer(var_C, (n, 4), "float32") + for i, j, k in T.grid(n, 4, 8): + with T.sblock("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + @I.pyfunc def preprocess(self, x): - """Use PyTorch for preprocessing — e.g., normalization.""" - mean = x.mean() - std = x.std() - return (x - mean) / (std + 1e-5) + """Normalize input — trivial in PyTorch, annoying in TIR.""" + return (x - x.mean()) / (x.std() + 1e-5) @I.pyfunc - def run_pipeline(self, x): - """Orchestrate: Python preprocessing -> TIR computation -> result.""" - # Step 1: Python-based preprocessing - normalized = self.preprocess(x) - - # Step 2: Convert and run TIR kernel - tvm_input = self._convert_pytorch_to_tvm(normalized) - tvm_result = self.call_tir( - self.scale_tir, [tvm_input], out_sinfo=R.Tensor((4,), "float32") + def forward(self, x, weights): + # Step 1: Python preprocessing + x_norm = self.preprocess(x) + + # Step 2: TIR matmul + x_tvm = self._convert_pytorch_to_tvm(x_norm) + w_tvm = self._convert_pytorch_to_tvm(weights) + out = self.call_tir( + self.matmul_tir, + [x_tvm, w_tvm], + out_sinfo=R.Tensor((x.shape[0], 4), "float32"), ) - return self._convert_tvm_to_pytorch(tvm_result) + logits = self._convert_tvm_to_pytorch(out) - @T.prim_func - def scale_tir(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): - for i in range(4): - B[i] = A[i] * T.float32(2.0) + # Step 3: Python softmax (fallback — could be replaced with TIR later) + return F.softmax(logits, dim=-1) - mod = HybridModule(device=tvm.cpu(0)) + mod = InferenceModule(device=tvm.cpu(0)) - x = torch.tensor([1.0, 3.0, 5.0, 7.0]) - 
result = mod.run_pipeline(x) - print("Pipeline result:", result) + batch = torch.randn(2, 8) + weights = torch.randn(8, 4) + probs = mod.forward(batch, weights) + + print("Input shape:", batch.shape) + print("Output probs:", probs) + print("Probs sum per row:", probs.sum(dim=-1)) # should be ~1.0 + assert torch.allclose(probs.sum(dim=-1), torch.ones(2), atol=1e-5) ###################################################################### -# Part 3: Adding Python Functions Dynamically -# --------------------------------------------- -# You can also register Python functions after module creation using ``add_python_function``. -# This is useful for attaching PyTorch-based fallback operators or custom post-processing. +# Step 3: Dynamic Function Registration +# ---------------------------------------- +# Sometimes you want to add a Python function after the module is created — for example, to +# swap in a different activation function or to register a custom op at runtime. Use +# ``add_python_function`` for this. 
if RUN_EXAMPLE: - py_mod = BasePyModule(tvm.IRModule({}), device=tvm.cpu(0)) - - # Dynamically add a Python function - def my_activation(x): - """Custom activation using PyTorch.""" - return torch.relu(x) + torch.tanh(x) + mod.add_python_function("gelu", lambda x: F.gelu(x)) - py_mod.add_python_function("my_activation", my_activation) - - # Now we can call it as a method - x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) - result = py_mod.my_activation(x) - print("Custom activation:", result) - expected = torch.relu(x) + torch.tanh(x) - assert torch.allclose(result, expected) + x = torch.randn(4) + result = mod.gelu(x) + print("Dynamically registered gelu:", result) + assert torch.allclose(result, F.gelu(x)) ###################################################################### -# Part 4: Relax-to-Python Function Converter -# -------------------------------------------- -# One powerful feature is the ability to automatically convert Relax functions into equivalent -# PyTorch code. This is useful for: -# -# - Numerically verifying Relax IR against PyTorch after applying optimization passes -# - Debugging: inspect what a Relax function actually computes by running it in Python -# - Prototyping: test Relax graph transformations without a full compilation cycle +# Step 4: Relax-to-Python Converter for Debugging +# -------------------------------------------------- +# After importing a model or applying passes, you end up with Relax IR. How do you know the IR +# is numerically correct? The ``RelaxToPyFuncConverter`` translates Relax functions into equivalent +# PyTorch code so you can compare outputs directly. # -# The ``RelaxToPyFuncConverter`` maps 300+ Relax operators to their PyTorch equivalents. +# This is especially useful after running optimization passes: convert the optimized Relax +# function back to PyTorch and compare against the original model's output. 
if RUN_EXAMPLE: from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter @@ -218,112 +215,174 @@ def my_activation(x): @I.ir_module class RelaxModel: @T.prim_func - def custom_add(var_x: T.handle, var_y: T.handle, var_out: T.handle): - x = T.match_buffer(var_x, (5,), "float32") - y = T.match_buffer(var_y, (5,), "float32") - out = T.match_buffer(var_out, (5,), "float32") - for i in range(5): - out[i] = x[i] + y[i] + def bias_add_tir(var_x: T.handle, var_b: T.handle, var_out: T.handle): + x = T.match_buffer(var_x, (2, 4), "float32") + b = T.match_buffer(var_b, (4,), "float32") + out = T.match_buffer(var_out, (2, 4), "float32") + for i, j in T.grid(2, 4): + out[i, j] = x[i, j] + b[j] @R.function def main( - x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") - ) -> R.Tensor((5,), "float32"): - # Mix of Relax ops and TIR calls - added = R.add(x, y) - activated = R.nn.relu(added) + x: R.Tensor((2, 4), "float32"), + w: R.Tensor((4, 4), "float32"), + b: R.Tensor((4,), "float32"), + ) -> R.Tensor((2, 4), "float32"): + # matmul + bias + relu — a typical dense layer + h = R.matmul(x, w) cls = RelaxModel - result = R.call_tir(cls.custom_add, (activated, y), out_sinfo=R.Tensor((5,), "float32")) - return result + h_bias = R.call_tir( + cls.bias_add_tir, (h, b), out_sinfo=R.Tensor((2, 4), "float32") + ) + return R.nn.relu(h_bias) - # Convert the Relax function "main" to an equivalent Python/PyTorch function + # Convert "main" to a Python/PyTorch function converter = RelaxToPyFuncConverter(RelaxModel) - converted_mod = converter.convert(["main"]) + converted = converter.convert(["main"]) - # The converted function lives in ir_mod.pyfuncs and accepts PyTorch tensors directly - x = torch.tensor([1.0, -2.0, 3.0, -4.0, 5.0]) - y = torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5]) + # Run through the converted Python function + x = torch.randn(2, 4) + w = torch.randn(4, 4) + b = torch.randn(4) - py_result = converted_mod.pyfuncs["main"](x, y) + py_result = 
converted.pyfuncs["main"](x, w, b) - # Manually compute the expected result for verification - step1 = torch.add(x, y) # [1.5, -1.5, 3.5, -3.5, 5.5] - step2 = torch.relu(step1) # [1.5, 0.0, 3.5, 0.0, 5.5] - expected = step2 + y # [2.0, 0.5, 4.0, 0.5, 6.0] + # Compare with manual PyTorch computation + expected = F.relu(x @ w + b) - print("Relax function converted to Python:") - print(" Input x:", x) - print(" Input y:", y) - print(" Python result:", py_result) - print(" Expected: ", expected) - assert torch.allclose(py_result, expected) + print("Converted Python result:", py_result) + print("PyTorch expected: ", expected) + assert torch.allclose(py_result, expected, atol=1e-5) ###################################################################### -# Part 5: Using R.call_py_func in Relax IR -# ------------------------------------------ -# ``R.call_py_func`` lets you embed Python function calls directly inside Relax IR. This means -# the compiled Relax VM can call back into Python at runtime. This is the bridge for ops that -# TVM cannot compile natively — the rest of the graph is compiled and optimized, while specific -# ops fall back to Python/PyTorch. +# Step 5: R.call_py_func — Python Callbacks in Compiled IR +# ----------------------------------------------------------- +# What if you want the compiled Relax VM (not just Python-side code) to call a Python function? +# ``R.call_py_func`` embeds a Python callback directly in Relax IR. The VM compiles and +# optimizes everything else, but calls back into Python for the specified op. # -# .. note:: -# ``R.call_py_func`` adds runtime overhead due to the Python-TVM boundary crossing. -# Use it for prototyping or for ops that are not performance-critical. +# Use case: your model has one custom op that is complex to implement in TIR. Compile +# everything else for performance, and let that one op run in Python. 
if RUN_EXAMPLE: @I.ir_module - class CallPyFuncModule(BasePyModule): + class HybridVMModule(BasePyModule): @I.pyfunc - def torch_relu(self, x): - """Python fallback: PyTorch ReLU.""" - return torch.relu(x) + def silu(self, x): + """SiLU activation — not yet a native Relax op, so we use Python.""" + return torch.sigmoid(x) * x @I.pyfunc - def torch_softmax(self, x, dim=0): - """Python fallback: PyTorch softmax.""" - return torch.softmax(x, dim=dim) + def layer_norm(self, x): + """LayerNorm — another Python fallback.""" + return F.layer_norm(x, x.shape[-1:]) @R.function - def main(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"): - # The VM calls back into Python for these ops at runtime - relu_result = R.call_py_func( - "torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32") + def main(x: R.Tensor((4, 8), "float32")) -> R.Tensor((4, 8), "float32"): + h = R.call_py_func( + "layer_norm", (x,), out_sinfo=R.Tensor((4, 8), "float32") ) - result = R.call_py_func( - "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32") + out = R.call_py_func( + "silu", (h,), out_sinfo=R.Tensor((4, 8), "float32") ) - return result + return out - mod = CallPyFuncModule(device=tvm.cpu(0)) + mod = HybridVMModule(device=tvm.cpu(0)) - x = torch.randn(10, dtype=torch.float32) + x = torch.randn(4, 8) - # call_py_func can be called directly from Python as well - relu_result = mod.call_py_func("torch_relu", [x]) - result = mod.call_py_func("torch_softmax", [relu_result]) + # call_py_func is callable from Python too + result = mod.call_py_func("layer_norm", [x]) + result = mod.call_py_func("silu", [result]) - expected = torch.softmax(torch.relu(x), dim=0) - print("R.call_py_func result:", result) - print("PyTorch expected: ", expected) + expected = torch.sigmoid(F.layer_norm(x, x.shape[-1:])) * F.layer_norm( + x, x.shape[-1:] + ) + print("call_py_func result:", result) assert torch.allclose(torch.tensor(result.numpy()), expected, atol=1e-5) 
+###################################################################### +# Step 6: Symbolic Shapes — Dynamic Batch Sizes +# ------------------------------------------------ +# Real models have dynamic shapes (e.g., variable batch size). TIR and Relax functions can +# declare symbolic dimensions. ``BasePyModule`` automatically infers concrete shapes from the +# input tensors at call time. + +if RUN_EXAMPLE: + + @I.ir_module + class DynamicModule(BasePyModule): + @T.prim_func + def scale_tir(var_x: T.handle, var_out: T.handle): + n = T.int64() + x = T.match_buffer(var_x, (n,), "float32") + out = T.match_buffer(var_out, (n,), "float32") + for i in T.serial(n): + out[i] = x[i] * T.float32(2.0) + + @R.function + def add_relax( + x: R.Tensor(("n",), "float32"), y: R.Tensor(("n",), "float32") + ) -> R.Tensor(("n",), "float32"): + return R.add(x, y) + + mod = DynamicModule(device=tvm.cpu(0), target="llvm") + + # Works with length 5 + a5 = torch.randn(5) + b5 = torch.randn(5) + out5 = mod.add_relax(a5, b5) + print("add_relax(len=5):", out5) + + # Same module, now length 10 — no recompilation needed + a10 = torch.randn(10) + b10 = torch.randn(10) + out10 = mod.add_relax(a10, b10) + print("add_relax(len=10):", out10) + + # call_tir with symbolic output shape + n = T.int64() + x7 = torch.randn(7) + scaled = mod.call_tir("scale_tir", [x7], relax.TensorStructInfo((n,), "float32")) + print("scale_tir(len=7):", scaled) + assert torch.allclose( + torch.tensor(scaled.numpy()), x7 * 2.0, atol=1e-5 + ) + + ###################################################################### # Summary # ------- -# This tutorial covered the Relax Python Module system: +# Here is what each step demonstrated and which PRs implement it: +# +# +--------+----------------------------------------+---------------------+ +# | Step | What you learned | Key PRs | +# +========+========================================+=====================+ +# | 1 | ``@I.pyfunc`` + ``call_tir`` basics, | #18229, #18331 | +# | | 
DLPack conversion, ``show()`` | #18253 | +# +--------+----------------------------------------+---------------------+ +# | 2 | Realistic pipeline: Python preprocess | #18229 | +# | | → TIR kernel → Python fallback | | +# +--------+----------------------------------------+---------------------+ +# | 3 | ``add_python_function`` for runtime | #18229 | +# | | registration | | +# +--------+----------------------------------------+---------------------+ +# | 4 | ``RelaxToPyFuncConverter``: verify | #18269, #18301 | +# | | Relax IR numerically against PyTorch | | +# +--------+----------------------------------------+---------------------+ +# | 5 | ``R.call_py_func``: Python callbacks | #18313, #18326 | +# | | inside compiled Relax VM | | +# +--------+----------------------------------------+---------------------+ +# | 6 | Symbolic shapes for dynamic inputs | #18288 | +# +--------+----------------------------------------+---------------------+ # -# - **BasePyModule**: A base class that unifies Python, TIR, and Relax functions in one module, -# with JIT compilation and DLPack-based tensor conversion. -# - **@I.pyfunc**: Decorator to mark Python functions inside an ``@I.ir_module`` class. -# - **Dynamic registration**: ``add_python_function()`` to attach Python functions after creation. -# - **RelaxToPyFuncConverter**: Automatically converts Relax functions to PyTorch for debugging -# and numerical verification. -# - **R.call_py_func**: Embeds Python function calls in Relax IR, enabling fallback to PyTorch -# for unsupported ops while keeping the rest of the graph compiled. +# The workflow in practice: # -# Together, these features make TVM a hybrid execution framework where you can freely mix -# compiled TVM code and Python/PyTorch, enabling faster iteration during development -# and gradual migration of ops from Python fallbacks to optimized TVM kernels. +# 1. Import a model → some ops unsupported → use ``@I.pyfunc`` as Python fallbacks +# 2. 
Get it running end-to-end with ``BasePyModule`` +# 3. Use ``RelaxToPyFuncConverter`` to verify correctness after optimization passes +# 4. Gradually replace Python fallbacks with TIR/Relax implementations +# 5. Use ``R.call_py_func`` for ops that must stay in Python even after compilation From 65c717a38573d9430dccd3be22166f374b6a1a8c Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sat, 28 Mar 2026 01:30:14 -0400 Subject: [PATCH 4/6] fnish10 --- .../mix_python_and_tvm_with_pymodule.py | 338 +++++++++++------- 1 file changed, 218 insertions(+), 120 deletions(-) diff --git a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py index 255a818c3f20..02df6f2dd17e 100644 --- a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py +++ b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py @@ -26,18 +26,19 @@ module compiles successfully**. If a single op is unsupported, the whole pipeline is blocked. ``BasePyModule`` solves this by letting Python functions, TIR kernels, and Relax functions coexist -in one module. TIR and Relax functions are JIT-compiled, Python functions run as-is, and tensors -move between TVM and PyTorch via zero-copy DLPack. This enables: +in one module. TIR and Relax functions are JIT-compiled on instantiation, Python functions run +as-is, and tensors move between TVM and PyTorch via zero-copy DLPack. This enables: - **Incremental development**: get a model running with Python fallbacks first, then replace them with TVM ops one by one. -- **Debugging at any stage**: convert Relax functions back to PyTorch to verify numerical - correctness after applying optimization passes. +- **Easy debugging**: insert ``print`` in Python functions to inspect intermediate tensors — no + need to compile the whole module first. +- **Verification at any compilation stage**: convert Relax IR back to PyTorch to check numerical + correctness before and after optimization passes. 
- **Hybrid execution**: let the compiled VM call back into Python for ops that are hard to express in TIR or Relax. -This tutorial walks through a concrete example: building a small model where Python, TIR, and -Relax functions work together, then using the converter and ``call_py_func`` to debug and extend it. +This tutorial walks through the full workflow step by step. .. contents:: Table of Contents :local: @@ -74,13 +75,12 @@ # The core idea: decorate a class with ``@I.ir_module``, inherit from ``BasePyModule``, and use # three decorators for three kinds of functions: # -# - ``@T.prim_func`` — low-level TIR kernel (compiled) -# - ``@R.function`` — high-level Relax graph (compiled) -# - ``@I.pyfunc`` — plain Python (runs as-is, can call PyTorch) +# - ``@T.prim_func`` — low-level TIR kernel (JIT-compiled on instantiation) +# - ``@R.function`` — high-level Relax graph (JIT-compiled on instantiation) +# - ``@I.pyfunc`` — plain Python (runs as-is, can use any Python library) # -# On instantiation, TIR and Relax functions are JIT-compiled. Python functions stay in Python. -# ``call_tir`` bridges them: it converts PyTorch tensors to TVM via DLPack, allocates the output, -# calls the compiled kernel, and converts back. +# ``call_tir`` bridges Python and TIR: it converts PyTorch tensors to TVM NDArrays via DLPack +# (zero-copy), allocates the output buffer, calls the compiled kernel, and converts back. 
if RUN_EXAMPLE: @@ -105,7 +105,7 @@ def forward(self, x, y): ) return self._convert_tvm_to_pytorch(result) - # TIR functions are JIT-compiled here + # TIR functions are JIT-compiled at instantiation mod = MyFirstModule(device=tvm.cpu(0)) x = torch.tensor([1.0, 2.0, 3.0, 4.0]) @@ -115,80 +115,151 @@ def forward(self, x, y): print("forward(x, y) =", result) assert torch.allclose(result, x + y) - # show() prints the TVMScript representation, including Python functions as ExternFunc + # show() prints TVMScript including Python functions (shown as ExternFunc) mod.show() ###################################################################### -# Step 2: A Realistic Pipeline — Python, TIR, and Relax Together -# ----------------------------------------------------------------- -# Real models are not just one op. Here we build a mini inference pipeline: -# -# 1. **Python** preprocesses the input (normalization — easy in PyTorch, verbose in TIR) -# 2. **TIR** runs a hand-written matmul kernel -# 3. **Python** applies softmax via PyTorch (a temporary fallback) -# -# The key point: you do not need every op to be a TIR kernel to get the module running. -# Write what you can in TIR, fall back to Python for the rest, and iterate. +# Step 2: Debugging — The Main Selling Point +# --------------------------------------------- +# Traditional ML compilers treat computation graphs as monolithic blobs. You cannot inspect +# intermediate tensor values without compiling the entire module. With ``@I.pyfunc``, debugging +# is as simple as adding a ``print`` statement. You can also make quick edits and re-run +# immediately — no recompilation needed. 
if RUN_EXAMPLE: @I.ir_module - class InferenceModule(BasePyModule): + class DebugModule(BasePyModule): @T.prim_func def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle): n = T.int32() - A = T.match_buffer(var_A, (n, 8), "float32") - B = T.match_buffer(var_B, (8, 4), "float32") - C = T.match_buffer(var_C, (n, 4), "float32") - for i, j, k in T.grid(n, 4, 8): + A = T.match_buffer(var_A, (n, 4), "float32") + B = T.match_buffer(var_B, (4, 3), "float32") + C = T.match_buffer(var_C, (n, 3), "float32") + for i, j, k in T.grid(n, 3, 4): with T.sblock("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - @I.pyfunc - def preprocess(self, x): - """Normalize input — trivial in PyTorch, annoying in TIR.""" - return (x - x.mean()) / (x.std() + 1e-5) - @I.pyfunc def forward(self, x, weights): - # Step 1: Python preprocessing - x_norm = self.preprocess(x) + # Inspect input + print(f" [DEBUG] input shape: {x.shape}, mean: {x.mean():.4f}") - # Step 2: TIR matmul - x_tvm = self._convert_pytorch_to_tvm(x_norm) + # Run TIR matmul + x_tvm = self._convert_pytorch_to_tvm(x) w_tvm = self._convert_pytorch_to_tvm(weights) out = self.call_tir( self.matmul_tir, [x_tvm, w_tvm], - out_sinfo=R.Tensor((x.shape[0], 4), "float32"), + out_sinfo=R.Tensor((x.shape[0], 3), "float32"), ) logits = self._convert_tvm_to_pytorch(out) - # Step 3: Python softmax (fallback — could be replaced with TIR later) - return F.softmax(logits, dim=-1) + # Inspect intermediate value — impossible with a compiled-only workflow + print(f" [DEBUG] logits shape: {logits.shape}, " + f"min: {logits.min():.4f}, max: {logits.max():.4f}") + + result = F.softmax(logits, dim=-1) - mod = InferenceModule(device=tvm.cpu(0)) + # Verify output + print(f" [DEBUG] probs sum: {result.sum(dim=-1)}") + return result - batch = torch.randn(2, 8) - weights = torch.randn(8, 4) - probs = mod.forward(batch, weights) + mod = 
DebugModule(device=tvm.cpu(0)) - print("Input shape:", batch.shape) - print("Output probs:", probs) - print("Probs sum per row:", probs.sum(dim=-1)) # should be ~1.0 + print("Running with debug prints:") + probs = mod.forward(torch.randn(2, 4), torch.randn(4, 3)) assert torch.allclose(probs.sum(dim=-1), torch.ones(2), atol=1e-5) +###################################################################### +# This is the workflow the RFC describes: "debugging is as simple as inserting a print statement. +# Users can also make quick, manual edits to Python functions and immediately observe the +# results." No compilation cycle, no VM loading — just Python. + + +###################################################################### +# Step 3: A Realistic Pipeline — Python, TIR, and Packed Functions +# ------------------------------------------------------------------- +# Real models combine many kinds of operations. This step builds a mini inference pipeline using +# three different calling conventions: +# +# - ``call_tir``: call a compiled TIR kernel +# - ``call_dps_packed``: call a TVM packed function (e.g., a third-party library binding) +# - Direct Python: call any PyTorch function +# +# ``call_dps_packed`` is useful for calling functions registered via ``tvm.register_global_func`` +# — for example, CUBLAS or cuDNN bindings that TVM wraps as packed functions. 
+ +if RUN_EXAMPLE: + + # Register a packed function (simulating an external library binding) + @tvm.register_global_func("my_bias_add", override=True) + def my_bias_add(x, bias, out): + """Packed function: adds bias to each row of x.""" + import numpy as np + + x_np = x.numpy() + b_np = bias.numpy() + out_np = x_np + b_np + out[:] = out_np + + @I.ir_module + class PipelineModule(BasePyModule): + @T.prim_func + def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle): + A = T.match_buffer(var_A, (2, 4), "float32") + B = T.match_buffer(var_B, (4, 3), "float32") + C = T.match_buffer(var_C, (2, 3), "float32") + for i, j, k in T.grid(2, 3, 4): + with T.sblock("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @I.pyfunc + def forward(self, x, weights, bias): + # 1. TIR matmul + x_tvm = self._convert_pytorch_to_tvm(x) + w_tvm = self._convert_pytorch_to_tvm(weights) + h = self.call_tir( + self.matmul_tir, [x_tvm, w_tvm], + out_sinfo=R.Tensor((2, 3), "float32"), + ) + h_pt = self._convert_tvm_to_pytorch(h) + + # 2. Packed function for bias add (simulating an external library) + h_biased = self.call_dps_packed( + "my_bias_add", [h_pt, bias], + out_sinfo=R.Tensor((2, 3), "float32"), + ) + + # 3. 
Python/PyTorch activation + return F.relu(h_biased) + + mod = PipelineModule(device=tvm.cpu(0)) + + x = torch.randn(2, 4) + w = torch.randn(4, 3) + b = torch.randn(3) + result = mod.forward(x, w, b) + + expected = F.relu(x @ w + b) + print("Pipeline result:", result) + print("Expected: ", expected) + assert torch.allclose(result, expected, atol=1e-4) + ###################################################################### -# Step 3: Dynamic Function Registration +# Step 4: Dynamic Function Registration # ---------------------------------------- -# Sometimes you want to add a Python function after the module is created — for example, to -# swap in a different activation function or to register a custom op at runtime. Use -# ``add_python_function`` for this. +# You can register Python functions after the module is created using ``add_python_function``. +# This is useful for swapping implementations at runtime — for example, testing different +# activation functions or registering a custom op. if RUN_EXAMPLE: mod.add_python_function("gelu", lambda x: F.gelu(x)) @@ -200,20 +271,22 @@ def forward(self, x, weights): ###################################################################### -# Step 4: Relax-to-Python Converter for Debugging -# -------------------------------------------------- -# After importing a model or applying passes, you end up with Relax IR. How do you know the IR -# is numerically correct? The ``RelaxToPyFuncConverter`` translates Relax functions into equivalent -# PyTorch code so you can compare outputs directly. +# Step 5: Relax-to-Python Converter — Verify at Any Compilation Stage +# ---------------------------------------------------------------------- +# Both Relax functions and Python functions describe computational graphs. The +# ``RelaxToPyFuncConverter`` converts Relax IR into equivalent PyTorch code by mapping +# Relax operators to their PyTorch counterparts (e.g., ``R.nn.relu`` → ``F.relu``). 
# -# This is especially useful after running optimization passes: convert the optimized Relax -# function back to PyTorch and compare against the original model's output. +# The key insight from the RFC: **this conversion can happen at any stage of compilation**. +# You can convert early (right after import) or late (after optimization passes have +# transformed the IR), and compare the output against a PyTorch reference to catch bugs. if RUN_EXAMPLE: from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter + # A simple Relax module: matmul + bias + relu (a dense layer) @I.ir_module - class RelaxModel: + class DenseLayer: @T.prim_func def bias_add_tir(var_x: T.handle, var_b: T.handle, var_out: T.handle): x = T.match_buffer(var_x, (2, 4), "float32") @@ -228,42 +301,61 @@ def main( w: R.Tensor((4, 4), "float32"), b: R.Tensor((4,), "float32"), ) -> R.Tensor((2, 4), "float32"): - # matmul + bias + relu — a typical dense layer h = R.matmul(x, w) - cls = RelaxModel + cls = DenseLayer h_bias = R.call_tir( - cls.bias_add_tir, (h, b), out_sinfo=R.Tensor((2, 4), "float32") + cls.bias_add_tir, (h, b), + out_sinfo=R.Tensor((2, 4), "float32"), ) return R.nn.relu(h_bias) - # Convert "main" to a Python/PyTorch function - converter = RelaxToPyFuncConverter(RelaxModel) - converted = converter.convert(["main"]) + # --- Stage 1: Convert BEFORE optimization --- + converter = RelaxToPyFuncConverter(DenseLayer) + converted_early = converter.convert(["main"]) - # Run through the converted Python function x = torch.randn(2, 4) w = torch.randn(4, 4) b = torch.randn(4) - py_result = converted.pyfuncs["main"](x, w, b) - - # Compare with manual PyTorch computation + py_result_early = converted_early.pyfuncs["main"](x, w, b) expected = F.relu(x @ w + b) - print("Converted Python result:", py_result) - print("PyTorch expected: ", expected) - assert torch.allclose(py_result, expected, atol=1e-5) + print("Before optimization:") + print(" Converted result:", py_result_early) + 
print(" PyTorch expected:", expected) + assert torch.allclose(py_result_early, expected, atol=1e-5) + + # --- Stage 2: Apply a pass, then convert AFTER optimization --- + # Run CanonicalizeBindings to clean up the IR, then convert again + # to verify the pass did not break numerical correctness. + optimized_mod = relax.transform.CanonicalizeBindings()(DenseLayer) + + converter_late = RelaxToPyFuncConverter(optimized_mod) + converted_late = converter_late.convert(["main"]) + + py_result_late = converted_late.pyfuncs["main"](x, w, b) + + print("\nAfter CanonicalizeBindings pass:") + print(" Converted result:", py_result_late) + print(" Still matches: ", + torch.allclose(py_result_late, expected, atol=1e-5)) + assert torch.allclose(py_result_late, expected, atol=1e-5) ###################################################################### -# Step 5: R.call_py_func — Python Callbacks in Compiled IR +# Step 6: R.call_py_func — Python Callbacks in Compiled IR # ----------------------------------------------------------- -# What if you want the compiled Relax VM (not just Python-side code) to call a Python function? -# ``R.call_py_func`` embeds a Python callback directly in Relax IR. The VM compiles and -# optimizes everything else, but calls back into Python for the specified op. +# ``R.call_py_func`` embeds a Python function call directly inside Relax IR. When the module +# is compiled and run in the VM, everything else is optimized native code, but the VM calls +# back into Python for the specified ops. # -# Use case: your model has one custom op that is complex to implement in TIR. Compile -# everything else for performance, and let that one op run in Python. +# This is the "cross-level call" design from the RFC: Relax functions can invoke Python +# functions, and Python functions can invoke TIR/Relax functions. Data flows between them +# via DLPack with minimal overhead. 
+# +# Use case: your model has a custom op (e.g., a special normalization or a sampling step) +# that is complex to implement in TIR. Compile everything else, and let that one op stay +# in Python. if RUN_EXAMPLE: @@ -271,7 +363,7 @@ def main( class HybridVMModule(BasePyModule): @I.pyfunc def silu(self, x): - """SiLU activation — not yet a native Relax op, so we use Python.""" + """SiLU/Swish activation — using Python as fallback.""" return torch.sigmoid(x) * x @I.pyfunc @@ -280,7 +372,10 @@ def layer_norm(self, x): return F.layer_norm(x, x.shape[-1:]) @R.function - def main(x: R.Tensor((4, 8), "float32")) -> R.Tensor((4, 8), "float32"): + def main( + x: R.Tensor((4, 8), "float32"), + ) -> R.Tensor((4, 8), "float32"): + # The VM calls back into Python for these two ops h = R.call_py_func( "layer_norm", (x,), out_sinfo=R.Tensor((4, 8), "float32") ) @@ -290,26 +385,24 @@ def main(x: R.Tensor((4, 8), "float32")) -> R.Tensor((4, 8), "float32"): return out mod = HybridVMModule(device=tvm.cpu(0)) - x = torch.randn(4, 8) - # call_py_func is callable from Python too + # call_py_func is also callable from Python directly result = mod.call_py_func("layer_norm", [x]) result = mod.call_py_func("silu", [result]) - expected = torch.sigmoid(F.layer_norm(x, x.shape[-1:])) * F.layer_norm( - x, x.shape[-1:] - ) + ln = F.layer_norm(x, x.shape[-1:]) + expected = torch.sigmoid(ln) * ln print("call_py_func result:", result) assert torch.allclose(torch.tensor(result.numpy()), expected, atol=1e-5) ###################################################################### -# Step 6: Symbolic Shapes — Dynamic Batch Sizes +# Step 7: Symbolic Shapes — Dynamic Batch Sizes # ------------------------------------------------ -# Real models have dynamic shapes (e.g., variable batch size). TIR and Relax functions can -# declare symbolic dimensions. ``BasePyModule`` automatically infers concrete shapes from the -# input tensors at call time. 
+# Real models have dynamic shapes (e.g., variable batch size or sequence length). TIR and Relax +# functions can declare symbolic dimensions using string names like ``"n"``. ``BasePyModule`` +# automatically infers concrete shapes from the actual input tensors at call time. if RUN_EXAMPLE: @@ -325,7 +418,8 @@ def scale_tir(var_x: T.handle, var_out: T.handle): @R.function def add_relax( - x: R.Tensor(("n",), "float32"), y: R.Tensor(("n",), "float32") + x: R.Tensor(("n",), "float32"), + y: R.Tensor(("n",), "float32"), ) -> R.Tensor(("n",), "float32"): return R.add(x, y) @@ -337,20 +431,20 @@ def add_relax( out5 = mod.add_relax(a5, b5) print("add_relax(len=5):", out5) - # Same module, now length 10 — no recompilation needed + # Same module, length 10 — no recompilation needed a10 = torch.randn(10) b10 = torch.randn(10) out10 = mod.add_relax(a10, b10) print("add_relax(len=10):", out10) - # call_tir with symbolic output shape + # call_tir also supports symbolic output shapes n = T.int64() x7 = torch.randn(7) - scaled = mod.call_tir("scale_tir", [x7], relax.TensorStructInfo((n,), "float32")) - print("scale_tir(len=7):", scaled) - assert torch.allclose( - torch.tensor(scaled.numpy()), x7 * 2.0, atol=1e-5 + scaled = mod.call_tir( + "scale_tir", [x7], relax.TensorStructInfo((n,), "float32") ) + print("scale_tir(len=7):", scaled) + assert torch.allclose(torch.tensor(scaled.numpy()), x7 * 2.0, atol=1e-5) ###################################################################### @@ -358,31 +452,35 @@ def add_relax( # ------- # Here is what each step demonstrated and which PRs implement it: # -# +--------+----------------------------------------+---------------------+ -# | Step | What you learned | Key PRs | -# +========+========================================+=====================+ -# | 1 | ``@I.pyfunc`` + ``call_tir`` basics, | #18229, #18331 | -# | | DLPack conversion, ``show()`` | #18253 | -# +--------+----------------------------------------+---------------------+ -# | 2 | 
Realistic pipeline: Python preprocess | #18229 | -# | | → TIR kernel → Python fallback | | -# +--------+----------------------------------------+---------------------+ -# | 3 | ``add_python_function`` for runtime | #18229 | -# | | registration | | -# +--------+----------------------------------------+---------------------+ -# | 4 | ``RelaxToPyFuncConverter``: verify | #18269, #18301 | -# | | Relax IR numerically against PyTorch | | -# +--------+----------------------------------------+---------------------+ -# | 5 | ``R.call_py_func``: Python callbacks | #18313, #18326 | -# | | inside compiled Relax VM | | -# +--------+----------------------------------------+---------------------+ -# | 6 | Symbolic shapes for dynamic inputs | #18288 | -# +--------+----------------------------------------+---------------------+ +# +--------+-------------------------------------------+---------------------+ +# | Step | What you learned | Key PRs | +# +========+===========================================+=====================+ +# | 1 | ``@I.pyfunc`` + ``call_tir`` basics, | #18229, #18331 | +# | | DLPack conversion, ``show()`` | #18253 | +# +--------+-------------------------------------------+---------------------+ +# | 2 | Debugging with ``print`` in pyfuncs — | #18229 | +# | | no compilation needed | | +# +--------+-------------------------------------------+---------------------+ +# | 3 | ``call_tir`` + ``call_dps_packed`` + | #18229 | +# | | Python in one pipeline | | +# +--------+-------------------------------------------+---------------------+ +# | 4 | ``add_python_function`` for runtime | #18229 | +# | | registration | | +# +--------+-------------------------------------------+---------------------+ +# | 5 | ``RelaxToPyFuncConverter``: verify Relax | #18269, #18301 | +# | | IR at different compilation stages | | +# +--------+-------------------------------------------+---------------------+ +# | 6 | ``R.call_py_func``: cross-level calls | #18313, #18326 | +# | | between 
compiled VM and Python | | +# +--------+-------------------------------------------+---------------------+ +# | 7 | Symbolic shapes for dynamic inputs | #18288 | +# +--------+-------------------------------------------+---------------------+ # # The workflow in practice: # # 1. Import a model → some ops unsupported → use ``@I.pyfunc`` as Python fallbacks # 2. Get it running end-to-end with ``BasePyModule`` -# 3. Use ``RelaxToPyFuncConverter`` to verify correctness after optimization passes -# 4. Gradually replace Python fallbacks with TIR/Relax implementations -# 5. Use ``R.call_py_func`` for ops that must stay in Python even after compilation +# 3. Debug by inserting ``print`` in pyfuncs — inspect intermediate tensors instantly +# 4. Use ``RelaxToPyFuncConverter`` to verify correctness after each optimization pass +# 5. Gradually replace Python fallbacks with TIR/Relax implementations +# 6. Use ``R.call_py_func`` for ops that must stay in Python even after compilation From f6042e9838abb361dc2761dbbf6703498dfb8802 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sat, 28 Mar 2026 01:37:29 -0400 Subject: [PATCH 5/6] fnish11 --- .../mix_python_and_tvm_with_pymodule.py | 60 +++++++++---------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py index 02df6f2dd17e..841e8818d6ab 100644 --- a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py +++ b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py @@ -118,6 +118,9 @@ def forward(self, x, y): # show() prints TVMScript including Python functions (shown as ExternFunc) mod.show() + # list_functions() shows what is available in the module + print("Available functions:", mod.list_functions()) + ###################################################################### # Step 2: Debugging — The Main Selling Point @@ -398,11 +401,22 @@ def main( 
###################################################################### -# Step 7: Symbolic Shapes — Dynamic Batch Sizes +# Step 7: Cross-Level Calls and Symbolic Shapes # ------------------------------------------------ -# Real models have dynamic shapes (e.g., variable batch size or sequence length). TIR and Relax -# functions can declare symbolic dimensions using string names like ``"n"``. ``BasePyModule`` -# automatically infers concrete shapes from the actual input tensors at call time. +# The RFC's core design is **cross-level interoperability**: Python functions can call TIR +# and Relax functions, and Relax functions can call Python functions. We have already seen: +# +# - Python → TIR via ``call_tir`` (Steps 1–3) +# - Python → packed function via ``call_dps_packed`` (Step 3) +# - Relax → Python via ``R.call_py_func`` (Step 6) +# +# The missing piece: **Python calling a compiled Relax function directly**. When a module +# contains ``@R.function``, it is JIT-compiled into a Relax VM. You can call it from Python +# just like any other method — the module auto-converts PyTorch tensors to TVM and back. +# +# This step also shows **symbolic shapes**: TIR and Relax functions can declare dynamic +# dimensions (e.g., ``"n"``). ``BasePyModule`` infers concrete shapes from the actual input +# tensors at call time, so the same module handles different sizes without recompilation. 
if RUN_EXAMPLE: @@ -425,19 +439,22 @@ def add_relax( mod = DynamicModule(device=tvm.cpu(0), target="llvm") - # Works with length 5 + # Inspect what the module contains + print("Functions:", mod.list_functions()) + + # Python → Relax: call the compiled Relax function directly with PyTorch tensors a5 = torch.randn(5) b5 = torch.randn(5) out5 = mod.add_relax(a5, b5) print("add_relax(len=5):", out5) - # Same module, length 10 — no recompilation needed + # Same module, different size — symbolic shapes handle this automatically a10 = torch.randn(10) b10 = torch.randn(10) out10 = mod.add_relax(a10, b10) print("add_relax(len=10):", out10) - # call_tir also supports symbolic output shapes + # Python → TIR with symbolic output shape n = T.int64() x7 = torch.randn(7) scaled = mod.call_tir( @@ -450,31 +467,12 @@ def add_relax( ###################################################################### # Summary # ------- -# Here is what each step demonstrated and which PRs implement it: +# Cross-level call summary (the RFC's core design): # -# +--------+-------------------------------------------+---------------------+ -# | Step | What you learned | Key PRs | -# +========+===========================================+=====================+ -# | 1 | ``@I.pyfunc`` + ``call_tir`` basics, | #18229, #18331 | -# | | DLPack conversion, ``show()`` | #18253 | -# +--------+-------------------------------------------+---------------------+ -# | 2 | Debugging with ``print`` in pyfuncs — | #18229 | -# | | no compilation needed | | -# +--------+-------------------------------------------+---------------------+ -# | 3 | ``call_tir`` + ``call_dps_packed`` + | #18229 | -# | | Python in one pipeline | | -# +--------+-------------------------------------------+---------------------+ -# | 4 | ``add_python_function`` for runtime | #18229 | -# | | registration | | -# +--------+-------------------------------------------+---------------------+ -# | 5 | ``RelaxToPyFuncConverter``: verify Relax | #18269, 
#18301 | -# | | IR at different compilation stages | | -# +--------+-------------------------------------------+---------------------+ -# | 6 | ``R.call_py_func``: cross-level calls | #18313, #18326 | -# | | between compiled VM and Python | | -# +--------+-------------------------------------------+---------------------+ -# | 7 | Symbolic shapes for dynamic inputs | #18288 | -# +--------+-------------------------------------------+---------------------+ +# - **Python → TIR**: ``call_tir()`` (Steps 1, 2, 3, 7) +# - **Python → packed function**: ``call_dps_packed()`` (Step 3) +# - **Python → Relax**: call ``@R.function`` as a method (Step 7) +# - **Relax → Python**: ``R.call_py_func()`` in compiled VM (Step 6) # # The workflow in practice: # From 48b8c145f7a992e99a58e0c659efc9124619101c Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sat, 28 Mar 2026 01:43:04 -0400 Subject: [PATCH 6/6] finish --- .../mix_python_and_tvm_with_pymodule.py | 46 ++++++------------- 1 file changed, 15 insertions(+), 31 deletions(-) diff --git a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py index 841e8818d6ab..91d1cb9c2633 100644 --- a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py +++ b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py @@ -179,7 +179,7 @@ def forward(self, x, weights): assert torch.allclose(probs.sum(dim=-1), torch.ones(2), atol=1e-5) ###################################################################### -# This is the workflow the RFC describes: "debugging is as simple as inserting a print statement. +# This is the key benefit: "debugging is as simple as inserting a print statement. # Users can also make quick, manual edits to Python functions and immediately observe the # results." No compilation cycle, no VM loading — just Python. 
@@ -258,29 +258,13 @@ def forward(self, x, weights, bias): ###################################################################### -# Step 4: Dynamic Function Registration -# ---------------------------------------- -# You can register Python functions after the module is created using ``add_python_function``. -# This is useful for swapping implementations at runtime — for example, testing different -# activation functions or registering a custom op. - -if RUN_EXAMPLE: - mod.add_python_function("gelu", lambda x: F.gelu(x)) - - x = torch.randn(4) - result = mod.gelu(x) - print("Dynamically registered gelu:", result) - assert torch.allclose(result, F.gelu(x)) - - -###################################################################### -# Step 5: Relax-to-Python Converter — Verify at Any Compilation Stage +# Step 4: Relax-to-Python Converter — Verify at Any Compilation Stage # ---------------------------------------------------------------------- # Both Relax functions and Python functions describe computational graphs. The # ``RelaxToPyFuncConverter`` converts Relax IR into equivalent PyTorch code by mapping # Relax operators to their PyTorch counterparts (e.g., ``R.nn.relu`` → ``F.relu``). # -# The key insight from the RFC: **this conversion can happen at any stage of compilation**. +# A key feature: **this conversion can happen at any stage of compilation**. # You can convert early (right after import) or late (after optimization passes have # transformed the IR), and compare the output against a PyTorch reference to catch bugs. @@ -346,15 +330,15 @@ def main( ###################################################################### -# Step 6: R.call_py_func — Python Callbacks in Compiled IR +# Step 5: R.call_py_func — Python Callbacks in Compiled IR # ----------------------------------------------------------- # ``R.call_py_func`` embeds a Python function call directly inside Relax IR. 
When the module # is compiled and run in the VM, everything else is optimized native code, but the VM calls # back into Python for the specified ops. # -# This is the "cross-level call" design from the RFC: Relax functions can invoke Python -# functions, and Python functions can invoke TIR/Relax functions. Data flows between them -# via DLPack with minimal overhead. +# ``BasePyModule`` supports cross-level calls in both directions: Relax functions can invoke +# Python functions, and Python functions can invoke TIR/Relax functions. Data flows between +# them via DLPack with minimal overhead. # # Use case: your model has a custom op (e.g., a special normalization or a sampling step) # that is complex to implement in TIR. Compile everything else, and let that one op stay @@ -401,14 +385,14 @@ def main( ###################################################################### -# Step 7: Cross-Level Calls and Symbolic Shapes +# Step 6: Cross-Level Calls and Symbolic Shapes # ------------------------------------------------ -# The RFC's core design is **cross-level interoperability**: Python functions can call TIR -# and Relax functions, and Relax functions can call Python functions. We have already seen: +# ``BasePyModule`` is designed for **cross-level interoperability**: Python functions can call +# TIR and Relax functions, and Relax functions can call Python functions. We have already seen: # # - Python → TIR via ``call_tir`` (Steps 1–3) # - Python → packed function via ``call_dps_packed`` (Step 3) -# - Relax → Python via ``R.call_py_func`` (Step 6) +# - Relax → Python via ``R.call_py_func`` (Step 5) # # The missing piece: **Python calling a compiled Relax function directly**. When a module # contains ``@R.function``, it is JIT-compiled into a Relax VM. 
You can call it from Python @@ -467,12 +451,12 @@ def add_relax( ###################################################################### # Summary # ------- -# Cross-level call summary (the RFC's core design): +# Cross-level call summary: # -# - **Python → TIR**: ``call_tir()`` (Steps 1, 2, 3, 7) +# - **Python → TIR**: ``call_tir()`` (Steps 1, 2, 3, 6) # - **Python → packed function**: ``call_dps_packed()`` (Step 3) -# - **Python → Relax**: call ``@R.function`` as a method (Step 7) -# - **Relax → Python**: ``R.call_py_func()`` in compiled VM (Step 6) +# - **Python → Relax**: call ``@R.function`` as a method (Step 6) +# - **Relax → Python**: ``R.call_py_func()`` in compiled VM (Step 5) # # The workflow in practice: #