diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 7ebb95c136f3..d5de1acf3d54 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -45,6 +45,17 @@ def __init__(self) -> None: ########## Utilities ########## + def update_convert_map(self, custom_convert_map: Dict[str, Callable]): + """Update self.convert_map with custom convert map + + Parameters + ---------- + custom_convert_map : Dict[str, Callable] + A custom op conversion map in the same format as self.convert_map + """ + + self.convert_map.update(custom_convert_map) + @staticmethod def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] = None): """converts the PyTorch scalar type input_type to a TVM dtype.""" diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 2ec61796c31a..54b60187e8d9 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -23,6 +23,7 @@ from typing import Callable, Dict, List, Optional, Tuple import torch +from torch import fx import tvm from tvm import relax @@ -32,8 +33,6 @@ class ExportedProgramImporter(BaseFXGraphImporter): """An importer from ExportedProgram to Relax.""" - from torch import fx - @staticmethod def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Tensor: """Convert a PyTorch tensor to TVM tensor, handling sparse tensors. @@ -1604,9 +1603,18 @@ def from_exported_program( keep_params_as_input: bool, unwrap_unit_return_tuple: bool, no_bind_return_tuple: bool, + custom_convert_map: Optional[ + Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]] + ], ) -> tvm.IRModule: """Convert a PyTorch ExportedProgram to a Relax program.""" - from torch import fx # type: ignore + + # Update the conversion map with custom ops if provided. + if custom_convert_map: + custom_ops = set(custom_convert_map.keys()) + self.update_convert_map(custom_convert_map) + else: + custom_ops = set() # Create input variables. ( @@ -1671,7 +1679,10 @@ def from_exported_program( self.env[node] = getattr(exported_program.graph_module, node.target) elif node.op == "call_function": func_name = node.target.__name__ - self.env[node] = self.convert_map[func_name](node) + if func_name in custom_ops: + self.env[node] = self.convert_map[func_name](node, self) + else: + self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") assert output is not None @@ -1711,6 +1722,9 @@ def from_exported_program( keep_params_as_input: bool = False, unwrap_unit_return_tuple: bool = False, no_bind_return_tuple: bool = False, + custom_convert_map: Optional[ + Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]] + ] = None, run_ep_decomposition: bool = True, ) -> tvm.IRModule: """Convert a PyTorch ExportedProgram to a Relax program @@ -1731,6 +1745,9 @@ def from_exported_program( A boolean flag indicating whether to bind the return tuple as a relax var. If the flag is true and the return value is a tuple, it will not bind it to a var. + custom_convert_map : Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]] + A custom op conversion map in the same format as ExportedProgramImporter.convert_map above + run_ep_decomposition : bool A boolean flag indicating whether to run PyTorch's decomposition on the exported program before translation. When True, high-level operators will @@ -1784,4 +1801,5 @@ def forward(self, input): keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple, + custom_convert_map, ) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 9c2d53a68581..8a4b7d4bdbb1 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1020,17 +1020,6 @@ def create_convert_map( "item": self._item, } - def update_convert_map(self, custom_convert_map: dict): - """Update self.convert_map with custom convert map - - Parameters - ---------- - custom_convert_map : Dictionary of str to Relax op - A custom op conversion map in the same format as self.convert_map - """ - - self.convert_map.update(custom_convert_map) - def from_fx( self, model, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 010bd026a8ba..0df1ce88ed15 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -42,6 +42,7 @@ def verify_model( unwrap_unit_return_tuple=False, no_bind_return_tuple=False, map_free_vars=False, + custom_convert_map=None, ): exported_program = export(torch_model, args=example_args, dynamic_shapes=dynamic_shapes) mod = from_exported_program( @@ -50,6 +51,7 @@ def verify_model( keep_params_as_input=keep_params_as_input, unwrap_unit_return_tuple=unwrap_unit_return_tuple, no_bind_return_tuple=no_bind_return_tuple, + custom_convert_map=custom_convert_map, ) binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()} @@ -6525,6 +6527,40 @@ def forward(self, x): from_exported_program(ep) +def test_custom_op(): + class AddOp(Module): + def forward(self, x, y): + return torch.ops.aten.add.Tensor(x, y) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((5,), dtype="float32"), + y: R.Tensor((5,), dtype="float32"), + ) -> R.Tuple(R.Tensor((5,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5,), dtype="float32") = R.subtract(x, y) + gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + from tvm.relax.frontend.torch.exported_program_translator import ( + ExportedProgramImporter, + ) + + def custom_add_converter(node: torch.fx.Node, self: ExportedProgramImporter) -> relax.Var: + x = self.env[node.args[0]] + y = self.env[node.args[1]] + + return self.block_builder.emit(R.subtract(x, y)) + + example_args = (torch.randn(5, dtype=torch.float32), torch.randn(5, dtype=torch.float32)) + verify_model( + AddOp(), example_args, {}, Expected, custom_convert_map={"add.Tensor": custom_add_converter} + ) + + def test_empty_like(): class EmptyLike(Module): def forward(self, data):