[Relax][PyTorch] Add support for Custom Ops for ExportedProgram frontend #18544

tlopex merged 2 commits into apache:main from
Conversation
Summary of Changes

This pull request enhances the PyTorch ExportedProgram frontend for Relax by adding support for custom operators. It provides a mechanism for users to supply their own conversion functions for PyTorch operations that are not handled by default, expanding the compatibility and flexibility of the Relax framework when importing models from PyTorch's ExportedProgram. This simplifies incorporating custom or specialized PyTorch operations into the Relax compilation flow and makes the system more extensible.
Code Review
This pull request introduces the ability to provide custom operator conversion maps when importing PyTorch ExportedPrograms into Relax. Key changes include adding an update_convert_map method to BaseFXGraphImporter (and removing a redundant one from FXGraphImporter), and modifying ExportedProgramImporter's from_exported_program method to accept and apply a custom_convert_map. This map is used to extend the supported operations and to dispatch custom conversion logic, with the _check_unsupported_func_type method also updated to account for these custom mappings. Review feedback suggests:

- correcting the type hint for custom_convert_map in from_exported_program to match the (node, importer) argument order,
- making the custom_convert_map type hint in update_convert_map more specific, and
- removing the redundant custom_convert_map parameter from _check_unsupported_func_type, since self.convert_map is already updated before it is called.
```python
keep_params_as_input: bool,
unwrap_unit_return_tuple: bool,
no_bind_return_tuple: bool,
custom_convert_map: Dict[str, Callable[[BaseFXGraphImporter, fx.Node], relax.Var]],
```
There is a mismatch between the type hint for custom_convert_map and how the callable is used. The type hint Callable[[BaseFXGraphImporter, fx.Node], relax.Var] implies a signature of (importer, node). However, the call at line 1680 is self.convert_map[func_name](node, self), which implies a signature of (node, importer). This will cause a runtime error due to mismatched argument types. Please correct the order of arguments in the type hint.
Suggested change:

```diff
- custom_convert_map: Dict[str, Callable[[BaseFXGraphImporter, fx.Node], relax.Var]],
+ custom_convert_map: Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]],
```
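Why the argument order matters: the importer looks up the converter by name and calls it with the node first. A minimal standalone sketch of that dispatch pattern (the Node and Importer classes here are illustrative stand-ins, not the actual TVM classes):

```python
from typing import Callable, Dict


class Node:
    """Illustrative stand-in for fx.Node."""

    def __init__(self, name: str):
        self.name = name


class Importer:
    """Illustrative stand-in for BaseFXGraphImporter."""

    def __init__(self, convert_map: Dict[str, Callable]):
        self.convert_map = convert_map

    def convert(self, node: Node) -> str:
        # The real importer dispatches as self.convert_map[func_name](node, self),
        # so converters must accept (node, importer) in that order.
        return self.convert_map[node.name](node, self)


def convert_add(node: Node, importer: Importer) -> str:
    # Signature is (node, importer) to match the dispatch above.
    return f"handled {node.name} via {type(importer).__name__}"


importer = Importer({"add.Tensor": convert_add})
print(importer.convert(Node("add.Tensor")))  # → handled add.Tensor via Importer
```

A converter typed as (importer, node) would receive the arguments swapped at call time, which is the mismatch flagged above.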
```python
########## Utilities ##########

def update_convert_map(self, custom_convert_map: dict):
```
The type hint for custom_convert_map is dict, which is too generic. Please use a more specific type hint like Dict[str, Callable] for better type safety and clarity. The necessary types are already imported in this file.
Suggested change:

```diff
- def update_convert_map(self, custom_convert_map: dict):
+ def update_convert_map(self, custom_convert_map: Dict[str, Callable]):
```
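Assuming update_convert_map merges the user's map into the importer's default table (a plausible reading of the PR, not the verbatim TVM code), its behavior can be sketched as:

```python
from typing import Callable, Dict


class BaseImporterSketch:
    """Illustrative stand-in for BaseFXGraphImporter; not the actual TVM class."""

    def __init__(self):
        # Default converters registered by the frontend.
        self.convert_map: Dict[str, Callable] = {
            "add.Tensor": lambda node, importer: "relax.add",
        }

    def update_convert_map(self, custom_convert_map: Dict[str, Callable]):
        # dict.update lets custom entries both add new ops and
        # override existing default converters on key collision.
        self.convert_map.update(custom_convert_map)


imp = BaseImporterSketch()
imp.update_convert_map({"my_op": lambda node, importer: "relax.custom"})
print(sorted(imp.convert_map))  # → ['add.Tensor', 'my_op']
```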
```diff
  def _check_unsupported_func_type(self, nodes: List[fx.Node], custom_convert_map: Dict[str, Callable] = None):
      supported_op_list = set(self.convert_map.keys())
      if custom_convert_map:
          supported_op_list = supported_op_list.union(custom_convert_map.keys())

      missing_func_types = list(
          {
              node.target.__name__
              for node in nodes
-             if node.op == "call_function" and node.target.__name__ not in self.convert_map
+             if node.op == "call_function" and node.target.__name__ not in supported_op_list
          }
      )
      assert not missing_func_types, f"Unsupported function types {missing_func_types}"
```
The custom_convert_map parameter is redundant here. The self.convert_map is updated with custom ops before this method is called in both fx_translator.py and exported_program_translator.py. You can simplify this method by removing the custom_convert_map parameter and directly using self.convert_map.
This change will also require updating the call site in python/tvm/relax/frontend/torch/exported_program_translator.py at line 1646 to self._check_unsupported_func_type(nodes).
```python
def _check_unsupported_func_type(self, nodes: List[fx.Node]):
    missing_func_types = list(
        {
            node.target.__name__
            for node in nodes
            if node.op == "call_function" and node.target.__name__ not in self.convert_map
        }
    )
    assert not missing_func_types, f"Unsupported function types {missing_func_types}"
```

Force-pushed from 32f69e1 to 78d38f4.
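The simplified check above can be exercised standalone with mock node objects (SimpleNamespace stands in for fx.Node; the op names are hypothetical):

```python
from types import SimpleNamespace
from typing import Dict

# Hypothetical supported-op table; only the keys matter for this check.
convert_map: Dict[str, object] = {"add.Tensor": None, "mul.Tensor": None}


def check_unsupported(nodes, convert_map):
    # Collect call_function targets that have no registered converter.
    missing = list(
        {
            node.target.__name__
            for node in nodes
            if node.op == "call_function" and node.target.__name__ not in convert_map
        }
    )
    assert not missing, f"Unsupported function types {missing}"


def make_node(op: str, name: str):
    # fx.Node.target is a callable whose __name__ identifies the op.
    return SimpleNamespace(op=op, target=SimpleNamespace(__name__=name))


nodes = [make_node("call_function", "add.Tensor"), make_node("placeholder", "x")]
check_unsupported(nodes, convert_map)  # passes: every call_function target is known
print("ok")
```

Non-call_function nodes (placeholders, outputs) are skipped, so only genuine operator calls can trigger the assertion.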
/gemini review
Code Review
This pull request adds support for custom operators in the PyTorch ExportedProgram frontend, mirroring the existing functionality in the FX frontend. The changes introduce a custom_convert_map parameter to from_exported_program, allowing users to provide their own conversion functions for specific PyTorch operators. The implementation correctly handles these custom converters by passing the importer instance to them, enabling access to the translation context. Additionally, the update_convert_map method has been refactored into the base class BaseFXGraphImporter to reduce code duplication, which is a good improvement.
My review focuses on improving type hint correctness for the new custom_convert_map parameter to better reflect that it can be None.
```python
keep_params_as_input: bool,
unwrap_unit_return_tuple: bool,
no_bind_return_tuple: bool,
custom_convert_map: Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]],
```
The from_exported_program function, which calls this method, has a default value of None for custom_convert_map. To ensure type consistency, this parameter's type hint should be Optional[Dict[...]] to reflect that it can be None.
Suggested change:

```diff
- custom_convert_map: Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]],
+ custom_convert_map: Optional[Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]]],
```
```python
keep_params_as_input: bool = False,
unwrap_unit_return_tuple: bool = False,
no_bind_return_tuple: bool = False,
custom_convert_map: Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]] = None,
```
For better type hinting and to follow standard Python typing practices, Optional should be used here to indicate that custom_convert_map can be None.
Suggested change:

```diff
- custom_convert_map: Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]] = None,
+ custom_convert_map: Optional[Dict[str, Callable[[fx.Node, BaseFXGraphImporter], relax.Var]]] = None,
```
Force-pushed from 78d38f4 to 941a9c8.
Force-pushed from 941a9c8 to be1ad45.
```python
x = self.env[node.args[0]]
y = self.env[node.args[1]]

return self.block_builder.emit(R.subtract(x, y))
```
I didn't catch this. Is this intended to be an addition or a subtraction? cc @mshr-h
I wanted to test that the importer can override the default import behavior. By default, the importer converts add.Tensor into relax.add, but the custom converter maps it to relax.subtract instead. @tlopex
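The override mechanism can be sketched standalone with mock converters (illustrative only; the real code emits Relax ops via the block builder):

```python
from typing import Callable, Dict

# Default converter table: add.Tensor normally lowers to relax.add.
default_convert_map: Dict[str, Callable] = {
    "add.Tensor": lambda node: "relax.add",
}

# User-supplied map: the same key overrides the default with relax.subtract.
custom_convert_map: Dict[str, Callable] = {
    "add.Tensor": lambda node: "relax.subtract",
}

convert_map = dict(default_convert_map)
convert_map.update(custom_convert_map)  # custom entries win on key collision

print(convert_map["add.Tensor"](None))  # → relax.subtract
```

So the test deliberately converts add.Tensor to a subtraction to prove the custom entry, not the default, is dispatched.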
As per title.

cc @tlopex @guan404ming

We keep the interface the same as from_fx(), so you can define and pass a custom converter something like this.
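A minimal sketch of what such a converter and map could look like. The converter body and the from_exported_program call are shown in comments because they require TVM and torch at runtime; only the map construction below is executed, and the converter's return value here is a placeholder string:

```python
from typing import Callable, Dict


def custom_add(node, importer):
    """Custom converter with the (node, importer) signature the frontend dispatches with."""
    # Hypothetical body when running under TVM, mirroring the test in this PR:
    #   x = importer.env[node.args[0]]
    #   y = importer.env[node.args[1]]
    #   return importer.block_builder.emit(relax.op.subtract(x, y))
    return "relax.subtract"  # placeholder so this sketch runs standalone


custom_convert_map: Dict[str, Callable] = {"add.Tensor": custom_add}

# With TVM installed, the map is passed to the importer, e.g.:
#   mod = from_exported_program(exported_program, custom_convert_map=custom_convert_map)
print(list(custom_convert_map))  # → ['add.Tensor']
```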