
[Relax][PyTorch] Add support for Custom Ops for ExportedProgram frontend#18544

Merged
tlopex merged 2 commits into apache:main from mshr-h:add-custom-converter
Dec 9, 2025

Conversation

@mshr-h
Contributor

@mshr-h mshr-h commented Dec 4, 2025

As per title.

cc @tlopex @guan404ming

We keep the interface the same as from_fx(), so you can define and pass a custom converter like this:

import torch
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program
from tvm.relax.frontend.torch.exported_program_translator import ExportedProgramImporter

def _rms_norm_converter(node: torch.fx.Node, self: ExportedProgramImporter) -> relax.Var:
    x = self.env[node.args[0]]
    torch_dtype = node.args[0].meta["tensor_meta"].dtype
    normalized_shape = node.args[1]
    weight = self.env.get(node.args[2], None) if len(node.args) > 2 else None
    eps = node.args[3] if len(node.args) > 3 else None

    # RMSNorm normalizes over the trailing len(normalized_shape) axes.
    N = len(self.shape_of(x))
    D = len(normalized_shape) if isinstance(normalized_shape, (tuple, list)) else 1
    axes = list(range(N - D, N))

    if weight is None:
        weight = self._convert_torch_tensor_to_relax(
            torch.ones(list(normalized_shape), dtype=torch_dtype)
        )
    if eps is None:
        eps = torch.finfo(torch_dtype).eps

    return self.block_builder.emit(relax.op.nn.rms_norm(x, weight, axes, eps))

mod = from_exported_program(
    exported_program,
    custom_convert_map={"rms_norm.default": _rms_norm_converter},
    run_ep_decomposition=False,
)

@gemini-code-assist
Contributor

Summary of Changes

Hello @mshr-h, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the PyTorch ExportedProgram frontend for Relax by integrating comprehensive support for custom operators. It provides a mechanism for users to supply their own conversion functions for PyTorch operations that are not inherently handled, thereby expanding the compatibility and flexibility of the Relax framework when importing models from PyTorch's ExportedProgram. This change simplifies the process of incorporating custom or specialized PyTorch operations into the Relax compilation flow, making the system more extensible.

Highlights

  • Custom Operator Support: Introduced a new custom_convert_map parameter in the from_exported_program function, allowing users to define and provide custom conversion logic for PyTorch operators not natively supported by Relax.
  • Enhanced Custom Converter Flexibility: Custom conversion functions can now receive the BaseFXGraphImporter instance as an argument, enabling them to leverage importer utilities or internal state during the conversion process for more complex scenarios.
  • Unified Custom Op Handling: The update_convert_map method was moved to the BaseFXGraphImporter class, and the _check_unsupported_func_type method was updated to recognize custom operators. This ensures a consistent and robust approach to handling custom operations across different FX graph importers.
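The merge-and-dispatch pattern these highlights describe can be sketched standalone. This is a simplified stand-in, not TVM's actual implementation; SketchImporter and its dict-based nodes are hypothetical, though the method name update_convert_map and the (node, importer) call order mirror the PR.

```python
# Standalone sketch of the convert-map merge-and-dispatch pattern (hypothetical
# SketchImporter; not TVM code).
from typing import Callable, Dict, Optional


class SketchImporter:
    def __init__(self):
        # Built-in converters keyed by op name; converters take (node, importer).
        self.convert_map: Dict[str, Callable] = {
            "add.Tensor": lambda node, self_: node["a"] + node["b"],
        }

    def update_convert_map(self, custom_convert_map: Optional[Dict[str, Callable]]):
        # Custom entries extend the map, and can also override built-in ops.
        if custom_convert_map:
            self.convert_map.update(custom_convert_map)

    def convert(self, node):
        # Dispatch mirrors self.convert_map[func_name](node, self) in the PR.
        return self.convert_map[node["op"]](node, self)


importer = SketchImporter()
# Override the default add.Tensor behavior with subtraction.
importer.update_convert_map({"add.Tensor": lambda node, self_: node["a"] - node["b"]})
print(importer.convert({"op": "add.Tensor", "a": 5, "b": 3}))  # 2
```

A converter registered under an existing key replaces the default, which is exactly how the test in this PR overrides add.Tensor.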
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces the ability to provide custom operator conversion maps when importing PyTorch ExportedPrograms into Relax. Key changes include adding an update_convert_map method to BaseFXGraphImporter (and removing a redundant one from FXGraphImporter), and modifying ExportedProgramImporter's from_exported_program method to accept and apply a custom_convert_map. This map is used to extend the supported operations and to dispatch custom conversion logic, with the _check_unsupported_func_type method also updated to account for these custom mappings. Review feedback suggests correcting a type hint for custom_convert_map in from_exported_program to match the (node, importer) argument order, making the custom_convert_map type hint in update_convert_map more specific, and removing the redundant custom_convert_map parameter from _check_unsupported_func_type as self.convert_map is already updated prior to its call.

keep_params_as_input: bool,
unwrap_unit_return_tuple: bool,
no_bind_return_tuple: bool,
custom_convert_map: Dict[str, Callable[[BaseFXGraphImporter, fx.Node], relax.Var]],
Contributor


critical

There is a mismatch between the type hint for custom_convert_map and how the callable is used. The type hint Callable[[BaseFXGraphImporter, fx.Node], relax.Var] implies a signature of (importer, node). However, the call at line 1680 is self.convert_map[func_name](node, self), which implies a signature of (node, importer). This will cause a runtime error due to mismatched argument types. Please correct the order of arguments in the type hint.

Suggested change
custom_convert_map: Dict[str, Callable[[BaseFXGraphImporter, fx.Node], relax.Var]],
custom_convert_map: Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]],
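The mismatch this comment flags can be demonstrated with stand-in classes (Node and Importer below are hypothetical, not torch.fx or TVM types): the importer calls converters as converter(node, importer), so a converter written against the incorrect (importer, node) hint operates on the wrong objects.

```python
# Hypothetical stand-ins demonstrating the argument-order mismatch.
class Node:
    """Stand-in for torch.fx.Node."""
    def __init__(self, target_name):
        self.target_name = target_name

class Importer:
    """Stand-in for BaseFXGraphImporter."""
    def __init__(self):
        self.convert_map = {}

    def run(self, node):
        # Mirrors the call site: self.convert_map[func_name](node, self)
        return self.convert_map[node.target_name](node, self)

def correct_converter(node, importer):
    # Written against the actual (node, importer) call order.
    return node.target_name

def swapped_converter(importer, node):
    # Written against the wrong hint: actually receives the Node as `importer`
    # and the Importer as `node`, so this attribute access fails.
    return node.target_name

imp = Importer()
imp.convert_map["my_op"] = correct_converter
print(imp.run(Node("my_op")))  # my_op

imp.convert_map["my_op"] = swapped_converter
try:
    imp.run(Node("my_op"))
except AttributeError as e:
    print("swapped converter fails:", e)
```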


########## Utilities ##########

def update_convert_map(self, custom_convert_map: dict):
Contributor


medium

The type hint for custom_convert_map is dict, which is too generic. Please use a more specific type hint like Dict[str, Callable] for better type safety and clarity. The necessary types are already imported in this file.

Suggested change
def update_convert_map(self, custom_convert_map: dict):
def update_convert_map(self, custom_convert_map: Dict[str, Callable]):

Comment on lines 185 to 197
    def _check_unsupported_func_type(self, nodes: List[fx.Node], custom_convert_map: Dict[str, Callable] = None):
        supported_op_list = set(self.convert_map.keys())
        if custom_convert_map:
            supported_op_list = supported_op_list.union(custom_convert_map.keys())

        missing_func_types = list(
            {
                node.target.__name__
                for node in nodes
-               if node.op == "call_function" and node.target.__name__ not in self.convert_map
+               if node.op == "call_function" and node.target.__name__ not in supported_op_list
            }
        )
        assert not missing_func_types, f"Unsupported function types {missing_func_types}"
Contributor


medium

The custom_convert_map parameter is redundant here. The self.convert_map is updated with custom ops before this method is called in both fx_translator.py and exported_program_translator.py. You can simplify this method by removing the custom_convert_map parameter and directly using self.convert_map.

This change will also require updating the call site in python/tvm/relax/frontend/torch/exported_program_translator.py at line 1646 to self._check_unsupported_func_type(nodes).

    def _check_unsupported_func_type(self, nodes: List[fx.Node]):
        missing_func_types = list(
            {
                node.target.__name__
                for node in nodes
                if node.op == "call_function" and node.target.__name__ not in self.convert_map
            }
        )
        assert not missing_func_types, f"Unsupported function types {missing_func_types}"
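The check above can be exercised standalone with minimal stand-in node objects (FakeTarget and FakeNode below are hypothetical substitutes for torch.fx types, not TVM code):

```python
# Stand-in objects exercising the unsupported-op check.
class FakeTarget:
    def __init__(self, name):
        self.__name__ = name

class FakeNode:
    def __init__(self, op, name):
        self.op = op
        self.target = FakeTarget(name)

# Simulated convert_map with one supported op.
convert_map = {"add.Tensor": lambda node, importer: None}
nodes = [
    FakeNode("call_function", "add.Tensor"),
    FakeNode("call_function", "rms_norm.default"),
    FakeNode("placeholder", "x"),  # non-call_function nodes are ignored
]

missing = list({
    n.target.__name__
    for n in nodes
    if n.op == "call_function" and n.target.__name__ not in convert_map
})
print(missing)  # ['rms_norm.default']
```

Since the custom entries are merged into self.convert_map before this check runs, a user-supplied converter for rms_norm.default would make the op count as supported.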

@mshr-h mshr-h force-pushed the add-custom-converter branch from 32f69e1 to 78d38f4 on December 4, 2025 09:36
@mshr-h
Contributor Author

mshr-h commented Dec 4, 2025

/gemini review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for custom operators in the PyTorch ExportedProgram frontend, mirroring the existing functionality in the FX frontend. The changes introduce a custom_convert_map parameter to from_exported_program, allowing users to provide their own conversion functions for specific PyTorch operators. The implementation correctly handles these custom converters by passing the importer instance to them, enabling access to the translation context. Additionally, the update_convert_map method has been refactored into the base class BaseFXGraphImporter to reduce code duplication, which is a good improvement.

My review focuses on improving type hint correctness for the new custom_convert_map parameter to better reflect that it can be None.

keep_params_as_input: bool,
unwrap_unit_return_tuple: bool,
no_bind_return_tuple: bool,
custom_convert_map: Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]],
Contributor


medium

The from_exported_program function, which calls this method, has a default value of None for custom_convert_map. To ensure type consistency, this parameter's type hint should be Optional[Dict[...]] to reflect that it can be None.

Suggested change
custom_convert_map: Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]],
custom_convert_map: Optional[Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]]],

keep_params_as_input: bool = False,
unwrap_unit_return_tuple: bool = False,
no_bind_return_tuple: bool = False,
custom_convert_map: Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]] = None,
Contributor


medium

For better type hinting and to follow standard Python typing practices, Optional should be used here to indicate that custom_convert_map can be None.

Suggested change
custom_convert_map: Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]] = None,
custom_convert_map: Optional[Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]]] = None,

@mshr-h mshr-h force-pushed the add-custom-converter branch from 78d38f4 to 941a9c8 on December 4, 2025 09:57
@mshr-h mshr-h force-pushed the add-custom-converter branch from 941a9c8 to be1ad45 on December 4, 2025 10:24
@mshr-h mshr-h marked this pull request as ready for review December 5, 2025 03:42
@github-actions github-actions Bot requested a review from tlopex December 5, 2025 03:42
@tlopex tlopex self-assigned this Dec 5, 2025
Member

@guan404ming guan404ming left a comment


LGTM, thanks!

x = self.env[node.args[0]]
y = self.env[node.args[1]]

return self.block_builder.emit(R.subtract(x, y))
Member


I didn't quite catch this. Is it intended to be an addition or a subtraction? cc @mshr-h

Contributor Author

@mshr-h mshr-h Dec 9, 2025


I wanted to test that the importer can override the default import behavior. By default, the importer converts add.Tensor into relax.add, but in the custom converter it is converted into relax.subtract. @tlopex

@tlopex tlopex merged commit 7271feb into apache:main Dec 9, 2025
19 of 20 checks passed
@mshr-h mshr-h deleted the add-custom-converter branch December 9, 2025 06:28