Merged
11 changes: 11 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -45,6 +45,17 @@ def __init__(self) -> None:

########## Utilities ##########

def update_convert_map(self, custom_convert_map: Dict[str, Callable]):
"""Update self.convert_map with custom convert map

Parameters
----------
custom_convert_map : Dict[str, Callable]
A custom op conversion map in the same format as self.convert_map
"""

self.convert_map.update(custom_convert_map)

@staticmethod
def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] = None):
"""converts the PyTorch scalar type input_type to a TVM dtype."""
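For context, a minimal usage sketch of this new base-class utility. The converter body and the "mul.Tensor" key are hypothetical illustrations; only update_convert_map itself comes from this diff.

from tvm import relax
from tvm.relax.frontend.torch.exported_program_translator import ExportedProgramImporter

def mul_converter(node, importer):
    # The node's inputs were translated earlier; look up their Relax expressions.
    lhs = importer.env[node.args[0]]
    rhs = importer.env[node.args[1]]
    return importer.block_builder.emit(relax.op.multiply(lhs, rhs))

importer = ExportedProgramImporter()
# Keys follow node.target.__name__, e.g. "mul.Tensor" for torch.ops.aten.mul.Tensor.
importer.update_convert_map({"mul.Tensor": mul_converter})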
26 changes: 22 additions & 4 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -23,6 +23,7 @@
from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import fx
import tvm
from tvm import relax

@@ -32,8 +33,6 @@
class ExportedProgramImporter(BaseFXGraphImporter):
"""An importer from ExportedProgram to Relax."""

from torch import fx

@staticmethod
def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Tensor:
"""Convert a PyTorch tensor to TVM tensor, handling sparse tensors.
@@ -1604,9 +1603,18 @@ def from_exported_program(
keep_params_as_input: bool,
unwrap_unit_return_tuple: bool,
no_bind_return_tuple: bool,
custom_convert_map: Optional[
Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]]
],
) -> tvm.IRModule:
"""Convert a PyTorch ExportedProgram to a Relax program."""
from torch import fx # type: ignore

# Update the conversion map with custom ops if provided.
if custom_convert_map:
custom_ops = set(custom_convert_map.keys())
self.update_convert_map(custom_convert_map)
else:
custom_ops = set()

# Create input variables.
(
@@ -1671,7 +1679,10 @@
self.env[node] = getattr(exported_program.graph_module, node.target)
elif node.op == "call_function":
func_name = node.target.__name__
self.env[node] = self.convert_map[func_name](node)
if func_name in custom_ops:
self.env[node] = self.convert_map[func_name](node, self)
else:
self.env[node] = self.convert_map[func_name](node)
else:
raise ValueError(f"Unsupported op {node.op}")
assert output is not None
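Note the asymmetry in the dispatch above: built-in convert_map entries are bound methods invoked as fn(node), while user-supplied entries are invoked as fn(node, self) so the converter can reach the importer's env and block_builder. A hypothetical custom entry therefore has the two-argument shape:

from tvm import relax

def my_converter(node, importer):
    # importer is the ExportedProgramImporter driving this translation.
    inp = importer.env[node.args[0]]
    return importer.block_builder.emit(relax.op.nn.relu(inp))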
@@ -1711,6 +1722,9 @@ def from_exported_program(
keep_params_as_input: bool = False,
unwrap_unit_return_tuple: bool = False,
no_bind_return_tuple: bool = False,
custom_convert_map: Optional[
Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]]
] = None,
run_ep_decomposition: bool = True,
) -> tvm.IRModule:
"""Convert a PyTorch ExportedProgram to a Relax program
@@ -1731,6 +1745,9 @@
A boolean flag indicating whether to bind the return tuple as a relax var.
If the flag is true and the return value is a tuple, it will not bind it to a var.

custom_convert_map : Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]]
A custom operator conversion map keyed in the same format as ExportedProgramImporter.convert_map above. Unlike the built-in entries, each custom converter also receives the importer instance as its second argument.

run_ep_decomposition : bool
A boolean flag indicating whether to run PyTorch's decomposition on the
exported program before translation. When True, high-level operators will
@@ -1784,4 +1801,5 @@ def forward(self, input):
keep_params_as_input,
unwrap_unit_return_tuple,
no_bind_return_tuple,
custom_convert_map,
)
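End to end, the new keyword threads through the public from_exported_program wrapper shown above. A small hypothetical example (the model, converter, and op key are illustrative, not part of this diff):

import torch
from torch.export import export

from tvm import relax
from tvm.relax.frontend.torch import from_exported_program

class MulModel(torch.nn.Module):
    def forward(self, x, y):
        return x * y

def mul_converter(node, importer):
    # Same two-argument shape as described above: (fx.Node, importer).
    lhs = importer.env[node.args[0]]
    rhs = importer.env[node.args[1]]
    return importer.block_builder.emit(relax.op.multiply(lhs, rhs))

ep = export(MulModel(), args=(torch.randn(3), torch.randn(3)))
# The key matches node.target.__name__ for the aten op, e.g. "mul.Tensor".
mod = from_exported_program(ep, custom_convert_map={"mul.Tensor": mul_converter})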
11 changes: 0 additions & 11 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -1020,17 +1020,6 @@ def create_convert_map(
"item": self._item,
}

def update_convert_map(self, custom_convert_map: dict):
"""Update self.convert_map with custom convert map

Parameters
----------
custom_convert_map : Dictionary of str to Relax op
A custom op conversion map in the same format as self.convert_map
"""

self.convert_map.update(custom_convert_map)

def from_fx(
self,
model,
36 changes: 36 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -42,6 +42,7 @@ def verify_model(
unwrap_unit_return_tuple=False,
no_bind_return_tuple=False,
map_free_vars=False,
custom_convert_map=None,
):
exported_program = export(torch_model, args=example_args, dynamic_shapes=dynamic_shapes)
mod = from_exported_program(
@@ -50,6 +51,7 @@
keep_params_as_input=keep_params_as_input,
unwrap_unit_return_tuple=unwrap_unit_return_tuple,
no_bind_return_tuple=no_bind_return_tuple,
custom_convert_map=custom_convert_map,
)

binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()}
@@ -6525,6 +6527,40 @@ def forward(self, x):
from_exported_program(ep)


def test_custom_op():
class AddOp(Module):
def forward(self, x, y):
return torch.ops.aten.add.Tensor(x, y)

@tvm.script.ir_module
class Expected:
@R.function
def main(
x: R.Tensor((5,), dtype="float32"),
y: R.Tensor((5,), dtype="float32"),
) -> R.Tuple(R.Tensor((5,), dtype="float32")):
with R.dataflow():
lv: R.Tensor((5,), dtype="float32") = R.subtract(x, y)
gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,)
R.output(gv)
return gv

from tvm.relax.frontend.torch.exported_program_translator import (
ExportedProgramImporter,
)

def custom_add_converter(node: torch.fx.Node, self: ExportedProgramImporter) -> relax.Var:
x = self.env[node.args[0]]
y = self.env[node.args[1]]

return self.block_builder.emit(R.subtract(x, y))
Member commented:

I didn't catch this here. Is this intended to be an addition or a subtraction? cc @mshr-h

@mshr-h (Contributor, Author) commented on Dec 9, 2025:
I wanted to test that the importer can override the default import behavior. By default, the importer converts add.Tensor into relax.add, but in the custom converter it is converted into relax.subtract. @tlopex


example_args = (torch.randn(5, dtype=torch.float32), torch.randn(5, dtype=torch.float32))
verify_model(
AddOp(), example_args, {}, Expected, custom_convert_map={"add.Tensor": custom_add_converter}
)


def test_empty_like():
class EmptyLike(Module):
def forward(self, data):