From 498f374e88cd25dc650853bc1cef441212854652 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Thu, 4 Sep 2025 21:41:17 +0800 Subject: [PATCH 1/4] finish1 --- .../relax_to_pyfunc_conversion_refactored.py | 246 ++++ python/tvm/relax/relax_to_pyfunc_converter.py | 1070 +++++++++++++++++ .../relax/test_relax_to_pyfunc_converter.py | 813 +++++++++++++ 3 files changed, 2129 insertions(+) create mode 100644 examples/relax_to_pyfunc_conversion_refactored.py create mode 100644 python/tvm/relax/relax_to_pyfunc_converter.py create mode 100644 tests/python/relax/test_relax_to_pyfunc_converter.py diff --git a/examples/relax_to_pyfunc_conversion_refactored.py b/examples/relax_to_pyfunc_conversion_refactored.py new file mode 100644 index 000000000000..8d1091b9d130 --- /dev/null +++ b/examples/relax_to_pyfunc_conversion_refactored.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +""" +Example: Converting Relax Functions to Python Functions (Refactored) + +This example demonstrates the new refactored architecture for converting Relax functions +to Python functions. The key improvement is that the converter now works with pure +IRModule objects, making it more modular and reusable. + +Key Features: +1. Pure IRModule conversion (no BasePyModule dependency) +2. Independent converter class +3. Convenience function for direct usage +4. BasePyModule integration for backward compatibility +""" + +import tvm +from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter, convert_relax_to_pyfunc +from tvm.relax.base_py_module import BasePyModule +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script import relax as R + + +@I.ir_module +class ExampleModule(BasePyModule): + """Example module with various Relax functions for conversion.""" + + @T.prim_func + def custom_add(var_x: T.handle, var_y: T.handle, var_out: T.handle): + """Custom TIR function for addition.""" + x = T.match_buffer(var_x, (5,), "float32") + y = T.match_buffer(var_y, (5,), "float32") + out = T.match_buffer(var_out, (5,), "float32") + + for i in range(5): + out[i] = x[i] + y[i] + + @R.function + def simple_math( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + """Simple mathematical operations.""" + # Basic arithmetic + add_result = R.add(x, y) + multiply_result = R.multiply(add_result, R.const(2.0, "float32")) + return multiply_result + + @R.function + def neural_network_layer( + x: R.Tensor((10, 20), "float32"), + weight: R.Tensor((20, 10), "float32"), + bias: R.Tensor((10,), "float32") + ) -> R.Tensor((10, 10), "float32"): + """Neural network layer with linear transformation and activation.""" + # Linear transformation + linear_out = R.matmul(x, weight) + # Add bias + biased_out = R.add(linear_out, bias) + # Apply ReLU activation + activated_out = R.nn.relu(biased_out) + return activated_out + + @R.function + def with_tir_call( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + """Function that calls a TIR function.""" + return R.call_tir(custom_add, (x, y), out_sinfo=R.Tensor((5,), "float32")) + + @R.function + def with_conditionals( + x: R.Tensor((5,), "float32"), threshold: R.Tensor((), "float32") + ) -> R.Tensor((5,), "float32"): + """Function with conditional logic.""" + # Create condition + condition = R.greater(x, threshold) + # True branch: multiply by 2 + true_result = R.multiply(x, R.const(2.0, "float32")) + # False branch: multiply by 0.5 + false_result = R.multiply(x, R.const(0.5, "float32")) + # Apply conditional + return R.where(condition, true_result, false_result) + + +def demo_pure_converter(): + """Demonstrate the pure converter approach.""" + print("=" * 80) + print("1. Pure Converter Approach (Recommended)") + print("=" * 80) + + ir_mod = ExampleModule + + # Create converter directly from IRModule + converter = RelaxToPyFuncConverter(ir_mod) + + print("\nConverting individual functions:") + print("-" * 50) + + # Convert single function + converted_ir_mod = converter.convert("simple_math") + simple_math_func = converted_ir_mod.pyfuncs["simple_math"] + result = simple_math_func("arg1", "arg2") + print(f"simple_math: {result}") + + # Convert multiple functions + converted_ir_mod = converter.convert(["neural_network_layer", "with_conditionals"]) + print(f"Converted functions: {list(converted_ir_mod.pyfuncs.keys())}") + + # Test converted functions + for func_name, func in converted_ir_mod.pyfuncs.items(): + if func_name == "neural_network_layer": + result = func("input", "weight", "bias") + elif func_name == "with_conditionals": + result = func("x", "threshold") + else: + result = func("arg1", "arg2") + print(f"{func_name}: {result}") + + +def demo_convenience_function(): + """Demonstrate the convenience function approach.""" + print("\n" + "=" * 80) + print("2. Convenience Function Approach") + print("=" * 80) + + ir_mod = ExampleModule + + print("\nUsing convenience function:") + print("-" * 50) + + # Convert using convenience function + converted_ir_mod = convert_relax_to_pyfunc(ir_mod, "simple_math") + simple_math_func = converted_ir_mod.pyfuncs["simple_math"] + result = simple_math_func("arg1", "arg2") + print(f"simple_math: {result}") + + # Convert multiple functions + converted_ir_mod = convert_relax_to_pyfunc(ir_mod, [ + "simple_math", + "neural_network_layer", + "with_conditionals" + ]) + print(f"Converted functions: {list(converted_ir_mod.pyfuncs.keys())}") + + +def demo_basepymodule_integration(): + """Demonstrate BasePyModule integration for backward compatibility.""" + print("\n" + "=" * 80) + print("3. BasePyModule Integration (Backward Compatibility)") + print("=" * 80) + + ir_mod = ExampleModule + device = tvm.cpu() + module = BasePyModule(ir_mod, device) + + print("\nUsing BasePyModule method:") + print("-" * 50) + + # Convert using BasePyModule method + converted_module = module.convert_relax_to_pyfunc("simple_math") + simple_math_func = converted_module.ir_mod.pyfuncs["simple_math"] + result = simple_math_func("arg1", "arg2") + print(f"simple_math: {result}") + + # Convert multiple functions + converted_module = module.convert_relax_to_pyfunc([ + "simple_math", + "neural_network_layer" + ]) + print(f"Converted functions: {list(converted_module.ir_mod.pyfuncs.keys())}") + + +def demo_operator_mapping(): + """Demonstrate the operator mapping functionality.""" + print("\n" + "=" * 80) + print("4. Operator Mapping") + print("=" * 80) + + ir_mod = ExampleModule + converter = RelaxToPyFuncConverter(ir_mod) + operator_map = converter.operator_map + + print(f"\nTotal operators mapped: {len(operator_map)}") + + # Show some key mappings + key_operators = [ + "relax.add", "relax.multiply", "relax.matmul", + "relax.nn.relu", "relax.nn.softmax", "relax.where" + ] + + print("\nKey operator mappings:") + for relax_op in key_operators: + if relax_op in operator_map: + pytorch_op = operator_map[relax_op] + print(f" {relax_op} -> {pytorch_op}") + + +def demo_architecture_benefits(): + """Demonstrate the benefits of the new architecture.""" + print("\n" + "=" * 80) + print("5. Architecture Benefits") + print("=" * 80) + + print("\nBenefits of the refactored architecture:") + print("-" * 50) + print("✓ Pure IRModule conversion (no BasePyModule dependency)") + print("✓ Independent converter class for reusability") + print("✓ Convenience function for direct usage") + print("✓ BasePyModule integration for backward compatibility") + print("✓ Better separation of concerns") + print("✓ Easier testing and maintenance") + print("✓ More modular design") + + print("\nUsage patterns:") + print("-" * 50) + print("# Pure converter (recommended)") + print("converter = RelaxToPyFuncConverter(ir_mod)") + print("converted_ir_mod = converter.convert(['func1', 'func2'])") + print() + print("# Convenience function") + print("converted_ir_mod = convert_relax_to_pyfunc(ir_mod, 'func1')") + print() + print("# BasePyModule integration") + print("module = BasePyModule(ir_mod, device)") + print("converted_module = module.convert_relax_to_pyfunc('func1')") + + +def main(): + """Main function demonstrating the refactored conversion process.""" + print("Relax to Python Function Conversion (Refactored Architecture)") + print("=" * 80) + + # Run all demos + demo_pure_converter() + demo_convenience_function() + demo_basepymodule_integration() + demo_operator_mapping() + demo_architecture_benefits() + + print("\n" + "=" * 80) + print("Refactored architecture demo completed successfully!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/python/tvm/relax/relax_to_pyfunc_converter.py b/python/tvm/relax/relax_to_pyfunc_converter.py new file mode 100644 index 000000000000..2a26a35ee502 --- /dev/null +++ b/python/tvm/relax/relax_to_pyfunc_converter.py @@ -0,0 +1,1070 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax to Python Function Converter. + +This module provides functionality to convert Relax functions to Python functions +that can be executed directly in Python/PyTorch environment. +""" + +from typing import Any, Dict, List, Union + +import tvm +from tvm import relax +from tvm.ir import IRModule, Op +import torch +import torch.nn.functional as F + + +class RelaxToPyFuncConverter: + """Converter that works with IRModule to convert Relax functions to Python functions. + + This converter transforms Relax functions into Python functions that can be executed + directly in Python/PyTorch environment. The conversion maps Relax operators to + corresponding PyTorch APIs and handles special cases like call_tir and call_dps_packed. + """ + + def __init__(self, ir_module: IRModule): + """Initialize the converter with an IRModule. + + Args: + ir_module: The IRModule containing Relax functions to convert + """ + self.ir_module = ir_module + self.operator_map = self._get_relax_to_pytorch_operator_map() + + def convert(self, relax_function_names: Union[str, List[str]]) -> IRModule: + """Convert specified Relax functions to Python functions. + + Args: + relax_function_names: Name(s) of Relax functions to convert + + Returns: + Updated IRModule with converted Python functions stored in pyfuncs + + Example: + >>> converter = RelaxToPyFuncConverter(ir_mod) + >>> # Convert a single function + >>> converted_ir_mod = converter.convert("my_relax_func") + >>> # Convert multiple functions + >>> converted_ir_mod = converter.convert(["func1", "func2"]) + """ + if isinstance(relax_function_names, str): + relax_function_names = [relax_function_names] + + # Create a copy of the current IRModule + new_ir_mod = self.ir_module.clone() + + # Initialize pyfuncs if not exists + if not hasattr(new_ir_mod, "pyfuncs"): + new_ir_mod.pyfuncs = {} + + # Get Relax function names from IRModule + relax_func_names = [] + for gv, func in self.ir_module.functions_items(): + if isinstance(func, relax.Function): + relax_func_names.append(gv.name_hint) + + # Convert each Relax function + for func_name in relax_function_names: + if func_name not in relax_func_names: + raise ValueError(f"Relax function '{func_name}' not found in IRModule") + + # Get the Relax function + relax_func = None + for gv, func in self.ir_module.functions_items(): + if gv.name_hint == func_name and isinstance(func, relax.Function): + relax_func = func + break + + if relax_func is None: + raise ValueError(f"Could not find Relax function '{func_name}'") + + # Convert to Python function + py_func = self._convert_relax_function_to_python(relax_func, func_name) + + # Store in pyfuncs + new_ir_mod.pyfuncs[func_name] = py_func + + return new_ir_mod + + def _convert_relax_function_to_python(self, relax_func: relax.Function, func_name: str) -> callable: + """Convert a single Relax function to a Python function.""" + # Get function parameters + params = relax_func.params + + # Create the Python function + def converted_function(*args, **kwargs): + """Converted Python function from Relax function.""" + # Handle arguments + if len(args) != len(params): + raise ValueError(f"Expected {len(params)} arguments, got {len(args)}") + + # Execute the converted function body + converter = RelaxExpressionConverter(self.operator_map, self.ir_module) + converter.current_params = params + return converter.convert_expr(relax_func.body, args) + + # Set function metadata + converted_function.__name__ = func_name + converted_function.__doc__ = f"Converted Python function from Relax function: {func_name}" + + return converted_function + + @staticmethod + def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: + """Get the mapping from Relax operators to PyTorch operators.""" + return { + # Binary operations + "relax.add": "torch.add", + "relax.subtract": "torch.sub", + "relax.multiply": "torch.mul", + "relax.divide": "torch.div", + "relax.power": "torch.pow", + "relax.maximum": "torch.maximum", + "relax.minimum": "torch.minimum", + "relax.floor_divide": "torch.floor_divide", + "relax.mod": "torch.fmod", + "relax.floor_mod": "torch.remainder", + "relax.log_add_exp": "torch.logaddexp", + + # Bitwise operations + "relax.bitwise_and": "torch.bitwise_and", + "relax.bitwise_or": "torch.bitwise_or", + "relax.bitwise_xor": "torch.bitwise_xor", + "relax.left_shift": "torch.left_shift", + "relax.right_shift": "torch.right_shift", + + # Unary operations + "relax.abs": "torch.abs", + "relax.negative": "torch.neg", + "relax.exp": "torch.exp", + "relax.log": "torch.log", + "relax.sqrt": "torch.sqrt", + "relax.rsqrt": "torch.rsqrt", + "relax.sin": "torch.sin", + "relax.cos": "torch.cos", + "relax.tanh": "torch.tanh", + "relax.sigmoid": "torch.sigmoid", + "relax.square": "torch.square", + "relax.sign": "torch.sign", + "relax.floor": "torch.floor", + "relax.ceil": "torch.ceil", + "relax.round": "torch.round", + "relax.trunc": "torch.trunc", + "relax.clip": "torch.clamp", + "relax.bitwise_not": "torch.bitwise_not", + + # Trigonometric functions + "relax.acos": "torch.acos", + "relax.asin": "torch.asin", + "relax.atan": "torch.atan", + "relax.cosh": "torch.cosh", + "relax.sinh": "torch.sinh", + "relax.tan": "torch.tan", + "relax.acosh": "torch.acosh", + "relax.asinh": "torch.asinh", + "relax.atanh": "torch.atanh", + + # Special functions + "relax.erf": "torch.erf", + "relax.isfinite": "torch.isfinite", + "relax.isinf": "torch.isinf", + "relax.isnan": "torch.isnan", + + # Neural network operations + "relax.nn.relu": "F.relu", + "relax.nn.relu6": "F.relu6", + "relax.nn.gelu": "F.gelu", + "relax.nn.gelu_tanh": "F.gelu", + "relax.nn.softmax": "F.softmax", + "relax.nn.log_softmax": "F.log_softmax", + "relax.nn.dropout": "F.dropout", + "relax.nn.batch_norm": "F.batch_norm", + "relax.nn.layer_norm": "F.layer_norm", + "relax.nn.group_norm": "F.group_norm", + "relax.nn.instance_norm": "F.instance_norm", + "relax.nn.rms_norm": "F.layer_norm", # Approximate mapping + "relax.nn.linear": "F.linear", + "relax.nn.conv1d": "F.conv1d", + "relax.nn.conv2d": "F.conv2d", + "relax.nn.conv3d": "F.conv3d", + "relax.nn.conv1d_transpose": "F.conv_transpose1d", + "relax.nn.conv2d_transpose": "F.conv_transpose2d", + "relax.nn.conv3d_transpose": "F.conv_transpose3d", + "relax.nn.max_pool1d": "F.max_pool1d", + "relax.nn.max_pool2d": "F.max_pool2d", + "relax.nn.max_pool3d": "F.max_pool3d", + "relax.nn.avg_pool1d": "F.avg_pool1d", + "relax.nn.avg_pool2d": "F.avg_pool2d", + "relax.nn.avg_pool3d": "F.avg_pool3d", + "relax.nn.adaptive_avg_pool1d": "F.adaptive_avg_pool1d", + "relax.nn.adaptive_avg_pool2d": "F.adaptive_avg_pool2d", + "relax.nn.adaptive_avg_pool3d": "F.adaptive_avg_pool3d", + "relax.nn.leakyrelu": "F.leaky_relu", + "relax.nn.prelu": "F.prelu", + "relax.nn.selu": "F.selu", + "relax.nn.silu": "F.silu", + "relax.nn.softplus": "F.softplus", + "relax.nn.attention": "F.scaled_dot_product_attention", # Approximate mapping + "relax.nn.cross_entropy_with_logits": "F.cross_entropy", + "relax.nn.nll_loss": "F.nll_loss", + "relax.nn.pad": "F.pad", + "relax.nn.pixel_shuffle": "F.pixel_shuffle", + + # Tensor operations + "relax.matmul": "torch.matmul", + "relax.linear": "F.linear", + "relax.einsum": "torch.einsum", + "relax.outer": "torch.outer", + "relax.reshape": "reshape", # Special handling needed + "relax.permute_dims": "permute_dims", # Special handling needed + "relax.expand_dims": "expand_dims", # Special handling needed + "relax.squeeze": "squeeze", # Special handling needed + "relax.concat": "concat", # Special handling needed + "relax.split": "split", # Special handling needed + "relax.stack": "stack", # Special handling needed + "relax.tile": "tile", # Special handling needed + "relax.repeat": "repeat", # Special handling needed + "relax.broadcast_to": "torch.broadcast_to", + "relax.flatten": "torch.flatten", + "relax.flip": "flip", # Special handling needed + "relax.roll": "torch.roll", + "relax.rot90": "torch.rot90", + "relax.meshgrid": "torch.meshgrid", + "relax.one_hot": "F.one_hot", + "relax.layout_transform": "torch.permute", # Approximate mapping + + # Indexing operations + "relax.take": "take", # Special handling needed + "relax.gather_elements": "torch.gather", + "relax.gather_nd": "torch.gather", + "relax.scatter_elements": "torch.scatter", + "relax.scatter_nd": "torch.scatter", + "relax.index_put": "torch.index_put", + "relax.index_tensor": "torch.index_select", + "relax.strided_slice": "torch.slice", + "relax.dynamic_strided_slice": "torch.slice", + "relax.slice_scatter": "torch.scatter", + + # Reduction operations + "relax.sum": "sum", # Special handling needed + "relax.mean": "mean", # Special handling needed + "relax.max": "max", # Special handling needed + "relax.min": "min", # Special handling needed + "relax.prod": "torch.prod", + "relax.std": "std", # Special handling needed + "relax.variance": "variance", # Special handling needed + "relax.cumsum": "torch.cumsum", + "relax.cumprod": "torch.cumprod", + "relax.argmax": "torch.argmax", + "relax.argmin": "torch.argmin", + + # Comparison operations + "relax.equal": "torch.eq", + "relax.not_equal": "torch.ne", + "relax.greater": "torch.gt", + "relax.greater_equal": "torch.ge", + "relax.less": "torch.lt", + "relax.less_equal": "torch.le", + + # Logical operations + "relax.logical_and": "torch.logical_and", + "relax.logical_or": "torch.logical_or", + "relax.logical_not": "torch.logical_not", + "relax.logical_xor": "torch.logical_xor", + + # Creation operations + "relax.zeros": "torch.zeros", + "relax.ones": "torch.ones", + "relax.full": "torch.full", + "relax.full_like": "torch.full_like", + "relax.zeros_like": "torch.zeros_like", + "relax.ones_like": "torch.ones_like", + "relax.arange": "torch.arange", + "relax.eye": "torch.eye", + "relax.eye_like": "torch.eye", + "relax.tril": "torch.tril", + "relax.triu": "torch.triu", + "relax.hamming_window": "torch.hamming_window", + + # Search operations + "relax.where": "torch.where", + "relax.bucketize": "torch.bucketize", + "relax.nonzero": "torch.nonzero", + "relax.unique": "torch.unique", + + # Sorting operations + "relax.sort": "torch.sort", + "relax.argsort": "torch.argsort", + "relax.topk": "torch.topk", + + # Sampling operations + "relax.multinomial_from_uniform": "torch.multinomial", + + # Ternary operations + "relax.ewise_fma": "torch.fma", # Approximate mapping + + # Data type operations + "relax.astype": "torch.to", + "relax.wrap_param": "torch.tensor", + + # Mask operations + "relax.masked_fill": "torch.masked_fill", + + # Quantization operations + "relax.quantize": "torch.quantize_per_tensor", # Approximate mapping + "relax.dequantize": "torch.dequantize", # Approximate mapping + + # Special operations (handled separately) + "relax.call_tir": "call_tir", + "relax.call_tir_inplace": "call_tir_inplace", + "relax.call_dps_packed": "call_dps_packed", + "relax.call_pure_packed": "call_pure_packed", + "relax.call_tir_with_grad": "call_tir_with_grad", + "relax.call_builtin_with_ctx": "call_builtin_with_ctx", + "relax.call_inplace_packed": "call_inplace_packed", + "relax.invoke_closure": "invoke_closure", + "relax.invoke_pure_closure": "invoke_pure_closure", + "relax.make_closure": "make_closure", + "relax.null_value": "null_value", + "relax.print": "print", + "relax.shape_of": "shape_of", + "relax.shape_to_tensor": "shape_to_tensor", + "relax.tensor_to_shape": "tensor_to_shape", + "relax.to_vdevice": "to_vdevice", + "relax.hint_on_device": "hint_on_device", + "relax.assert_op": "assert_op", + } + + +class RelaxExpressionConverter: + """Converter that transforms Relax expressions to Python/PyTorch code.""" + + def __init__(self, operator_map: Dict[str, str], ir_module: IRModule = None): + """Initialize the expression converter. + + Args: + operator_map: Mapping from Relax operators to PyTorch operators + ir_module: The IRModule containing TIR functions to compile + """ + self.operator_map = operator_map + self.variable_map: Dict[str, Any] = {} + self.current_params: List[relax.Var] = [] + self.ir_module = ir_module + + def convert_expr(self, expr: relax.Expr, args: List[Any]) -> Any: + """Convert a Relax expression to Python/PyTorch equivalent.""" + if isinstance(expr, relax.Var): + return self._convert_var(expr, args) + elif isinstance(expr, relax.Call): + return self._convert_call(expr, args) + elif isinstance(expr, relax.Constant): + return self._convert_constant(expr) + elif isinstance(expr, relax.SeqExpr): + return self._convert_seq_expr(expr, args) + elif isinstance(expr, relax.Tuple): + return self._convert_tuple(expr, args) + elif isinstance(expr, relax.TupleGetItem): + return self._convert_tuple_get_item(expr, args) + elif isinstance(expr, relax.If): + return self._convert_if(expr, args) + elif isinstance(expr, relax.ShapeExpr): + return self._convert_shape_expr(expr) + else: + # Fallback for unknown expression types + return f"" + + def _convert_var(self, var: relax.Var, args: List[Any]) -> Any: + """Convert a Relax variable to Python equivalent.""" + if hasattr(var, 'name_hint'): + var_name = var.name_hint + + # Check if it's a function parameter + for i, param in enumerate(self.current_params): + if hasattr(param, 'name_hint') and param.name_hint == var_name: + return args[i] + + # Check if it's a bound variable + if var_name in self.variable_map: + return self.variable_map[var_name] + + # Return placeholder for unbound variables + return f"" + return f"" + + def _convert_call(self, call: relax.Call, args: List[Any]) -> Any: + """Convert a Relax call to Python/PyTorch equivalent.""" + op = call.op + + # Handle different types of calls + if isinstance(op, relax.GlobalVar): + # Function call + return self._convert_function_call(call, args) + elif isinstance(op, Op): + # Operator call + return self._convert_operator_call(call, args) + elif isinstance(op, relax.ExternFunc): + # External function call (like call_tir, call_dps_packed) + return self._convert_extern_func_call(call, args) + else: + return f"" + + def _convert_function_call(self, call: relax.Call, args: List[Any]) -> Any: + """Convert a Relax function call.""" + func_name = call.op.name_hint + call_args = [self.convert_expr(arg, args) for arg in call.args] + + # Handle special cases + if func_name in ["call_tir", "call_tir_inplace"]: + return self._convert_call_tir(call, args) + elif func_name in ["call_dps_packed", "call_pure_packed"]: + return self._convert_call_dps_packed(call, args) + else: + # Regular function call + return f"" + + def _convert_operator_call(self, call: relax.Call, args: List[Any]) -> Any: + """Convert a Relax operator call to PyTorch equivalent.""" + op_name = call.op.name + call_args = [self.convert_expr(arg, args) for arg in call.args] + + # Get PyTorch equivalent + pytorch_op = self.operator_map.get(op_name) + if pytorch_op: + try: + # Handle special operations + if pytorch_op == "call_tir": + return self._convert_call_tir(call, args) + elif pytorch_op == "call_tir_inplace": + return self._convert_call_tir(call, args) + elif pytorch_op == "call_dps_packed": + return self._convert_call_dps_packed(call, args) + elif pytorch_op == "call_pure_packed": + return self._convert_call_dps_packed(call, args) + elif pytorch_op == "expand_dims": + return self._convert_expand_dims(call, args) + elif pytorch_op in ["sum", "mean", "max", "min", "std", "variance"]: + return self._convert_reduction_op(call, args, pytorch_op) + elif pytorch_op == "squeeze": + return self._convert_squeeze(call, args) + elif pytorch_op in ["concat", "split", "stack"]: + return self._convert_tensor_ops(call, args, pytorch_op) + elif pytorch_op == "reshape": + return self._convert_reshape(call, args) + elif pytorch_op == "permute_dims": + return self._convert_permute_dims(call, args) + elif pytorch_op == "take": + return self._convert_take(call, args) + elif pytorch_op == "flip": + return self._convert_flip(call, args) + elif pytorch_op == "tile": + return self._convert_tile(call, args) + elif pytorch_op == "repeat": + return self._convert_repeat(call, args) + # Handle special cases for PyTorch operations + elif pytorch_op.startswith("F."): + # Neural network function + func_name = pytorch_op[2:] # Remove "F." prefix + func = getattr(F, func_name) + + # Special handling for functions that need dim parameter + if func_name in ["softmax", "log_softmax"]: + # Extract axis from call.attrs and convert to dim + axis = None + if call.attrs and hasattr(call.attrs, 'axis'): + axis = call.attrs.axis + if hasattr(axis, 'value'): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + return func(call_args[0], dim=axis) + else: + # Default to last dimension if no axis specified + return func(call_args[0], dim=-1) + else: + return func(*call_args) + elif pytorch_op.startswith("torch."): + # Regular PyTorch operation + func_name = pytorch_op[6:] # Remove "torch." prefix + func = getattr(torch, func_name) + return func(*call_args) + else: + # Direct function reference + return eval(pytorch_op)(*call_args) + except Exception as e: + # This allows the test framework to catch and handle the errors appropriately + if pytorch_op.startswith("torch.") or pytorch_op.startswith("F."): + raise e + else: + # Fallback to string representation for non-PyTorch operations + return f"" + else: + # Unknown operator + return f"" + + def _convert_extern_func_call(self, call: relax.Call, args: List[Any]) -> Any: + """Convert an external function call.""" + func_name = call.op.global_symbol + call_args = [self.convert_expr(arg, args) for arg in call.args] + + if func_name in ["call_tir", "call_tir_inplace"]: + return self._convert_call_tir(call, args) + elif func_name in ["call_dps_packed", "call_pure_packed"]: + return self._convert_call_dps_packed(call, args) + else: + return f"" + + def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any: + """Convert call_tir to Python equivalent with DLPack conversion.""" + # Extract TIR function name and arguments + tir_func = call.args[0] + tir_args = call.args[1] if len(call.args) > 1 else [] + out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None + + # Get function name + if isinstance(tir_func, relax.GlobalVar): + func_name = tir_func.name_hint + else: + # Convert the GlobalVar expression + func_name = self.convert_expr(tir_func, args) + if isinstance(func_name, str) and func_name.startswith("<"): + # If it's a placeholder, extract the name + func_name = str(tir_func) + + # Convert arguments to PyTorch tensors + converted_args = [self.convert_expr(arg, args) for arg in tir_args] + + try: + # First, try to get the TIR function from the current IRModule + tir_function = None + if self.ir_module: + # Look for the TIR function in the current IRModule + for gv, func in self.ir_module.functions.items(): + if gv.name_hint == func_name and hasattr(func, 'body'): + try: + # Compile the TIR function + target = tvm.target.Target("llvm") + with tvm.target.Target(target): + tir_function = tvm.compile(func, target=target) + break + except Exception as compile_e: + print(f"Warning: Failed to compile TIR function {func_name}: {compile_e}") + continue + + # If not found in current module, try global registry + if tir_function is None: + tir_function = tvm.get_global_func(func_name) + + if tir_function is None: + return f"" + + # Convert PyTorch tensors to TVM NDArrays via DLPack + tvm_args = [] + for arg in converted_args: + if isinstance(arg, torch.Tensor): + # Convert PyTorch tensor to TVM NDArray via DLPack + tvm_arg = tvm.nd.from_dlpack(torch.to_dlpack(arg)) + tvm_args.append(tvm_arg) + else: + tvm_args.append(arg) + + # For call_tir, we need to allocate output tensor + output_shape = None + if out_sinfo and hasattr(out_sinfo, 'shape'): + output_shape = out_sinfo.shape + elif converted_args: + # Use the shape of the first input tensor + first_arg = converted_args[0] + if isinstance(first_arg, torch.Tensor): + output_shape = first_arg.shape + + if output_shape is None: + return f"" + + # Allocate output tensor + output_tensor = tvm.nd.array(tvm.nd.empty(output_shape, dtype="float32")) + tvm_args.append(output_tensor) + + # Call the TIR function + tir_function(*tvm_args) + + # The result is in the output_tensor we allocated + # Convert result back to PyTorch tensor via DLPack + return torch.from_dlpack(output_tensor.to_dlpack()) + + except Exception as e: + return f"" + + def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any: + """Convert call_dps_packed to Python equivalent with DLPack conversion.""" + # Extract packed function name and arguments + packed_func = call.args[0] + packed_args = call.args[1] if len(call.args) > 1 else [] + out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None + + # Get function name + if isinstance(packed_func, relax.GlobalVar): + func_name = packed_func.name_hint + elif isinstance(packed_func, relax.ExternFunc): + func_name = packed_func.global_symbol + else: + func_name = str(packed_func) + + # Convert arguments to PyTorch tensors + converted_args = [self.convert_expr(arg, args) for arg in packed_args] + + try: + # Get the packed function from TVM + packed_function = tvm.get_global_func(func_name) + if packed_function is None: + return f"" + + # Convert PyTorch tensors to TVM NDArrays via DLPack + tvm_args = [] + for arg in converted_args: + if isinstance(arg, torch.Tensor): + # Convert PyTorch tensor to TVM NDArray via DLPack + tvm_arg = tvm.nd.from_dlpack(torch.to_dlpack(arg)) + tvm_args.append(tvm_arg) + else: + tvm_args.append(arg) + + # Call the packed function + result = packed_function(*tvm_args) + + # Convert result back to PyTorch tensor via DLPack + if isinstance(result, tvm.nd.NDArray): + return torch.from_dlpack(result.to_dlpack()) + else: + return result + + except Exception as e: + return f"" + + def _convert_constant(self, const: relax.Constant) -> Any: + """Convert a Relax constant to Python equivalent.""" + if hasattr(const, 'data'): + data = const.data + # Convert TVM NDArray to Python scalar if it's a scalar + if hasattr(data, 'numpy'): + numpy_data = data.numpy() + if numpy_data.size == 1: + return float(numpy_data.item()) + else: + # For multi-element arrays, convert to PyTorch tensor + return torch.from_numpy(numpy_data) + elif hasattr(data, 'item'): + # Single element tensor + return data.item() + else: + return data + return f"" + + def _convert_seq_expr(self, seq: relax.SeqExpr, args: List[Any]) -> Any: + """Convert a Relax sequence expression.""" + # Convert blocks + for block in seq.blocks: + if hasattr(block, 'bindings'): + for binding in block.bindings: + if isinstance(binding, relax.VarBinding): + var_name = binding.var.name_hint + value = self.convert_expr(binding.value, args) + self.variable_map[var_name] = value + + # Convert body + return self.convert_expr(seq.body, args) + + def _convert_tuple(self, tuple_expr: relax.Tuple, args: List[Any]) -> Any: + """Convert a Relax tuple to Python tuple.""" + elements = [self.convert_expr(elem, args) for elem in tuple_expr.fields] + return tuple(elements) + + def _convert_tuple_get_item(self, get_item: relax.TupleGetItem, args: List[Any]) -> Any: + """Convert a Relax tuple get item to Python equivalent.""" + tuple_expr = self.convert_expr(get_item.tuple_value, args) + index = get_item.index + return f"" + + def _convert_if(self, if_expr: relax.If, args: List[Any]) -> Any: + """Convert a Relax if expression to Python equivalent.""" + condition = self.convert_expr(if_expr.cond, args) + true_branch = self.convert_expr(if_expr.true_branch, args) + false_branch = self.convert_expr(if_expr.false_branch, args) + return f"" + + def _convert_expand_dims(self, call: relax.Call, args: List[Any]) -> Any: + """Convert expand_dims to torch.unsqueeze with proper axis handling.""" + if len(call.args) < 1: + return f"" + + # Convert the tensor argument + tensor_arg = self.convert_expr(call.args[0], args) + + # Get the axis from call.attrs + axis = None + if call.attrs and hasattr(call.attrs, 'axis'): + axis = call.attrs.axis + # Handle different types of axis + if hasattr(axis, '__iter__') and not isinstance(axis, str): + # It's an array/list, take the first element + axis = list(axis)[0] if len(axis) > 0 else None + + # Handle TVM types + if hasattr(axis, 'value'): + # It's a TVM IntImm or similar, get the value + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is None: + return f"" + + # Use torch.unsqueeze with the correct axis + return torch.unsqueeze(tensor_arg, dim=axis) + + def _convert_reduction_op(self, call: relax.Call, args: List[Any], op_name: str) -> Any: + """Convert reduction operations with axis and keepdims parameters.""" + if len(call.args) < 1: + return f"<{op_name}_error: insufficient arguments>" + + # Convert the tensor argument + tensor_arg = self.convert_expr(call.args[0], args) + + # Get axis and keepdims from call.attrs + axis = None + keepdims = False + + if call.attrs: + if hasattr(call.attrs, 'axis') and call.attrs.axis is not None: + axis = call.attrs.axis + # Handle different types of axis + if hasattr(axis, '__iter__') and not isinstance(axis, str): + # It's an array/list, convert to list of ints + axis = [int(item.value) if hasattr(item, 'value') else int(item) for item in axis] + elif hasattr(axis, 'value'): + # It's a TVM IntImm, get the value + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if hasattr(call.attrs, 'keepdims'): + keepdims = bool(call.attrs.keepdims) + + # Get the PyTorch function + func = getattr(torch, op_name) + + # Call with appropriate parameters + if axis is not None: + # For max and min, PyTorch returns (values, indices) tuple when dim is specified + if op_name in ["max", "min"]: + if isinstance(axis, list) and len(axis) == 1: + axis = axis[0] + elif isinstance(axis, list) and len(axis) > 1: + axis = axis[0] + result = func(tensor_arg, axis, keepdim=keepdims) + if isinstance(result, tuple): + return result[0] + else: + return result + else: + return func(tensor_arg, dim=axis, keepdim=keepdims) + else: + return func(tensor_arg) + + def _convert_squeeze(self, call: relax.Call, args: List[Any]) -> Any: + """Convert squeeze to torch.squeeze with proper axis handling.""" + if len(call.args) < 1: + return f"" + + # Convert the tensor argument + tensor_arg = self.convert_expr(call.args[0], args) + + # Get axis from call.attrs + axis = None + if call.attrs and hasattr(call.attrs, 'axis') and call.attrs.axis is not None: + axis = call.attrs.axis + # Handle different types of axis + if hasattr(axis, '__iter__') and not isinstance(axis, str): + axis = [int(item.value) if hasattr(item, 'value') else int(item) for item in axis] + elif hasattr(axis, 'value'): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + # Call torch.squeeze with appropriate parameters + if axis is not None: + return torch.squeeze(tensor_arg, dim=axis) + else: + return torch.squeeze(tensor_arg) + + def _convert_tensor_ops(self, call: relax.Call, args: List[Any], op_name: str) -> Any: + """Convert tensor operations like concat, split, stack.""" + if len(call.args) < 1: + return f"<{op_name}_error: insufficient arguments>" + + # Convert arguments + converted_args = [self.convert_expr(arg, args) for arg in call.args] + + if op_name == "concat": + # torch.cat(tensors, dim=0) + # In Relax, concat takes a tuple of tensors as first argument + if len(converted_args) == 1 and isinstance(converted_args[0], tuple): + # This is a tuple of tensors + tensors = converted_args[0] + else: + # Direct tensor arguments + tensors = converted_args + axis = 0 + if call.attrs and hasattr(call.attrs, 'axis'): + axis = call.attrs.axis + if hasattr(axis, 'value'): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + return torch.cat(tensors, dim=axis) + + elif op_name == "split": + # torch.split(tensor, split_size_or_sections, dim=0) + tensor = converted_args[0] + split_size = converted_args[1] if len(converted_args) > 1 else 1 + axis = 0 + if call.attrs and hasattr(call.attrs, 'axis'): + axis = call.attrs.axis + if hasattr(axis, 'value'): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + # Handle indices_or_sections parameter + if call.attrs and hasattr(call.attrs, 'indices_or_sections'): + indices_or_sections = call.attrs.indices_or_sections + if hasattr(indices_or_sections, 'value'): + indices_or_sections = int(indices_or_sections.value) + elif isinstance(indices_or_sections, (int, float)): + indices_or_sections = int(indices_or_sections) + + # If indices_or_sections is an integer, it means split into N equal parts + if isinstance(indices_or_sections, int): + total_size = tensor.shape[axis] + split_size = total_size // indices_or_sections + return torch.split(tensor, split_size, dim=axis) + else: + # If it's a list, use it directly + return torch.split(tensor, indices_or_sections, dim=axis) + else: + return torch.split(tensor, split_size, dim=axis) + + elif op_name == "stack": + # torch.stack(tensors, dim=0) + if len(converted_args) == 1 and isinstance(converted_args[0], tuple): + tensors = converted_args[0] + else: + tensors = converted_args + axis = 0 + if call.attrs and hasattr(call.attrs, 'axis'): + axis = call.attrs.axis + if hasattr(axis, 'value'): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + return torch.stack(tensors, dim=axis) + + else: + return f"<{op_name}_error: unsupported operation>" + + def _convert_reshape(self, call: relax.Call, args: List[Any]) -> Any: + """Convert reshape operation.""" + if len(call.args) < 2: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + shape_arg = call.args[1] + + # Convert shape argument to Python tuple + if isinstance(shape_arg, relax.ShapeExpr): + if hasattr(shape_arg, 'values'): + shape = tuple(int(v.value) if hasattr(v, 'value') else int(v) for v in shape_arg.values) + else: + shape = (int(shape_arg),) + elif isinstance(shape_arg, relax.Constant): + # Constant tensor case + shape_data = shape_arg.data.numpy() + shape = tuple(int(v) for v in shape_data) + else: + # Try to convert as expression + converted_shape = self.convert_expr(shape_arg, args) + if isinstance(converted_shape, (list, tuple)): + shape = tuple(int(v) for v in converted_shape) + else: + shape = (int(converted_shape),) + + return torch.reshape(tensor_arg, shape) + + + def _convert_permute_dims(self, call: relax.Call, args: List[Any]) -> Any: + """Convert permute_dims operation.""" + if len(call.args) < 1: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + + # Extract axes from call.attrs + if call.attrs and hasattr(call.attrs, 'axes'): + axes = call.attrs.axes + # Handle TVM Array type + if hasattr(axes, '__iter__') and not isinstance(axes, str): + # Convert TVM Array or Python list/tuple to tuple of ints + axes = tuple(int(v.value) if hasattr(v, 'value') else int(v) for v in axes) + elif isinstance(axes, (list, tuple)): + axes = tuple(int(v) for v in axes) + else: + axes = (int(axes),) + else: + return "" + + return torch.permute(tensor_arg, axes) + + def _convert_take(self, call: relax.Call, args: List[Any]) -> Any: + """Convert take operation.""" + if len(call.args) < 2: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + indices_arg = self.convert_expr(call.args[1], args) + + # Extract axis from call.attrs + axis = None + if call.attrs and hasattr(call.attrs, 'axis'): + axis = call.attrs.axis + if hasattr(axis, 'value'): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + # Use advanced indexing for specific axis + if axis == 0: + return tensor_arg[indices_arg] + else: + # For other axes, we need to use torch.index_select + return torch.index_select(tensor_arg, dim=axis, index=indices_arg) + else: + # No axis specified, use torch.take (flattens the tensor) + return torch.take(tensor_arg, indices_arg) + + def _convert_flip(self, call: relax.Call, args: List[Any]) -> Any: + """Convert flip operation.""" + if len(call.args) < 1: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + + # Extract axis from call.attrs + axis = None + if call.attrs and hasattr(call.attrs, 'axis'): + axis = call.attrs.axis + if hasattr(axis, 'value'): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + # Convert single axis to list for torch.flip + dims = [axis] + else: + # Default: flip all dimensions + dims = list(range(tensor_arg.dim())) + + return torch.flip(tensor_arg, dims=dims) + + def _convert_tile(self, call: relax.Call, args: List[Any]) -> Any: + """Convert tile operation.""" + if len(call.args) < 1: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + + # Extract repeats from call.attrs + if call.attrs and hasattr(call.attrs, 'repeats'): + repeats = call.attrs.repeats + # Handle TVM Array type + if hasattr(repeats, '__iter__') and not isinstance(repeats, str): + repeats = tuple(int(v.value) if hasattr(v, 'value') else int(v) for v in repeats) + elif isinstance(repeats, (list, tuple)): + repeats = tuple(int(v) for v in repeats) + else: + repeats = (int(repeats),) + else: + return "" + + return torch.tile(tensor_arg, dims=repeats) + + def _convert_repeat(self, call: relax.Call, args: List[Any]) -> Any: + """Convert repeat operation.""" + if len(call.args) < 1: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + + # Extract repeats and axis from call.attrs + repeats = 1 + axis = None + + if call.attrs and hasattr(call.attrs, 'repeats'): + repeats = call.attrs.repeats + if hasattr(repeats, 'value'): + repeats = int(repeats.value) + elif isinstance(repeats, (int, float)): + repeats = int(repeats) + + if call.attrs and hasattr(call.attrs, 'axis'): + axis = call.attrs.axis + if hasattr(axis, 'value'): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + return torch.repeat_interleave(tensor_arg, repeats=repeats, dim=axis) + else: + return torch.repeat_interleave(tensor_arg, repeats=repeats) + + def _convert_shape_expr(self, shape_expr: relax.ShapeExpr) -> Any: + """Convert a Relax shape expression to Python equivalent.""" + if hasattr(shape_expr, 'values'): + return f"" + return f"" + + +def convert_relax_to_pyfunc(ir_module: IRModule, relax_function_names: Union[str, List[str]]) -> IRModule: + """Convert Relax functions to Python functions. + + Args: + ir_module: The IRModule containing Relax functions + relax_function_names: Name(s) of Relax functions to convert + + Returns: + IRModule with converted Python functions stored in pyfuncs + + Example: + >>> converted_ir_mod = convert_relax_to_pyfunc(ir_mod, "my_function") + >>> converted_ir_mod = convert_relax_to_pyfunc(ir_mod, ["func1", "func2"]) + """ + converter = RelaxToPyFuncConverter(ir_module) + return converter.convert(relax_function_names) diff --git a/tests/python/relax/test_relax_to_pyfunc_converter.py b/tests/python/relax/test_relax_to_pyfunc_converter.py new file mode 100644 index 000000000000..6dc72adafdb6 --- /dev/null +++ b/tests/python/relax/test_relax_to_pyfunc_converter.py @@ -0,0 +1,813 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Comprehensive test cases for Relax to PyFunc converter. +Tests all major features including basic operations, call_tir, call_dps_packed, and symbolic shapes. +""" + + +import pytest +import torch +import torch.nn.functional as F +import numpy as np + + +import tvm +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script import relax as R +from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter + + +@I.ir_module +class ComprehensiveTestModule: + """Test module covering all converter features.""" + + @T.prim_func + def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + """TIR function for addition.""" + x = T.match_buffer(var_x, (5,), "float32") + y = T.match_buffer(var_y, (5,), "float32") + out = T.match_buffer(var_out, (5,), "float32") + for i in range(5): + out[i] = x[i] + y[i] + + @T.prim_func + def mul_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + """TIR function for multiplication.""" + x = T.match_buffer(var_x, (3, 4), "float32") + y = T.match_buffer(var_y, (3, 4), "float32") + out = T.match_buffer(var_out, (3, 4), "float32") + for i in range(3): + for j in range(4): + out[i, j] = x[i, j] * y[i, j] + + @R.function + def simple_add( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.add(x, y) + + @R.function + def with_relu(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.nn.relu(x) + + @R.function + def with_call_tir( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + cls = ComprehensiveTestModule + return R.call_tir(cls.add_tir, (x, y), out_sinfo=R.Tensor((5,), "float32")) + + @R.function + def with_call_dps_packed( + x: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.call_dps_packed("my_softmax", (x, R.prim_value(1)), out_sinfo=R.Tensor((5,), "float32")) + + @R.function + def complex_function( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + added = R.add(x, y) + relued = R.nn.relu(added) + cls = ComprehensiveTestModule + tir_result = R.call_tir(cls.add_tir, (relued, y), out_sinfo=R.Tensor((5,), "float32")) + return R.nn.relu(tir_result) + + @R.function + def symbolic_add( + x: R.Tensor(("n",), "float32"), y: R.Tensor(("n",), "float32") + ) -> R.Tensor(("n",), "float32"): + return R.add(x, y) + + @R.function + def symbolic_matmul( + x: R.Tensor(("batch", "m", "k"), "float32"), + y: R.Tensor(("batch", "k", "n"), "float32") + ) -> R.Tensor(("batch", "m", "n"), "float32"): + return R.matmul(x, y) + + @R.function + def symbolic_expand_dims( + x: R.Tensor(("batch", "seq_len"), "float32") + ) -> R.Tensor(("batch", "seq_len", 1), "float32"): + return R.expand_dims(x, axis=2) + + @R.function + def multi_ops( + x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32") + ) -> R.Tensor((3, 4), "float32"): + added = R.add(x, y) + multiplied = R.multiply(added, y) + powered = R.power(multiplied, R.const(2.0)) + maxed = R.maximum(powered, x) + return maxed + + @R.function + def reduction_ops(x: R.Tensor((5,), "float32")) -> R.Tensor((), "float32"): + sum_val = R.sum(x) + mean_val = R.mean(x) + max_val = R.max(x) + return R.add(R.add(sum_val, mean_val), max_val) + + @R.function + def comparison_ops( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "bool"): + eq_val = R.equal(x, y) + gt_val = R.greater(x, y) + return R.logical_and(eq_val, gt_val) + + @R.function + def test_reshape(x: R.Tensor((2, 3), "float32")) -> R.Tensor((6,), "float32"): + return R.reshape(x, (6,)) + + @R.function + def test_permute_dims(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((4, 2, 3), "float32"): + return R.permute_dims(x, axes=[2, 0, 1]) + + @R.function + def test_concat(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 3), "float32"): + return R.concat((x, y), axis=0) + + @R.function + def test_split(x: R.Tensor((4, 3), "float32")) -> R.Tuple: + return R.split(x, indices_or_sections=2, axis=0) + + @R.function + def test_stack(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 2, 3), "float32"): + return R.stack((x, y), axis=1) + + @R.function + def test_take(x: R.Tensor((3, 4), "float32"), indices: R.Tensor((2,), "int64")) -> R.Tensor((2,), "float32"): + return R.take(x, indices, axis=0) + + @R.function + def test_flip(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + return R.flip(x, axis=1) + + @R.function + def test_tile(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 6), "float32"): + return R.tile(x, (2, 2)) + + @R.function + def test_repeat(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 3), "float32"): + return R.repeat(x, repeats=2, axis=0) + + @R.function + def test_expand_dims(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3, 1), "float32"): + return R.expand_dims(x, axis=2) + + @R.function + def test_squeeze(x: R.Tensor((2, 3, 1), "float32")) -> R.Tensor((2, 3), "float32"): + return R.squeeze(x, axis=2) + + @R.function + def test_sum_with_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"): + return R.sum(x, axis=0) + + @R.function + def test_max_with_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"): + return R.max(x, axis=0) + + +def create_mock_packed_function(): + """Create a mock packed function for testing.""" + def mock_softmax(x, axis): + """Mock softmax function that just returns the input.""" + return x + + # Register the function globally + tvm.register_func("my_softmax", mock_softmax) + + +class TestRelaxToPyFuncConverter: + """Comprehensive test class for Relax to PyFunc converter.""" + + @classmethod + def setup_class(cls): + """Set up test fixtures.""" + cls.ir_mod = ComprehensiveTestModule + cls.converter = RelaxToPyFuncConverter(cls.ir_mod) + create_mock_packed_function() + + def test_basic_operations(self): + """Test basic arithmetic operations.""" + converted_ir_mod = self.converter.convert(["simple_add", "with_relu"]) + + # Test simple_add + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs['simple_add'](x, y) + expected = torch.add(x, y) + assert torch.allclose(result, expected) + + # Test with_relu + x_neg = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32) + result = converted_ir_mod.pyfuncs['with_relu'](x_neg) + expected = torch.nn.functional.relu(x_neg) + assert torch.allclose(result, expected) + + def test_call_tir(self): + """Test call_tir functionality with DLPack conversion.""" + converted_ir_mod = self.converter.convert(["with_call_tir"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs['with_call_tir'](x, y) + expected = torch.add(x, y) + assert torch.allclose(result, expected) + assert result.shape == expected.shape + + def test_call_dps_packed(self): + """Test call_dps_packed functionality.""" + converted_ir_mod = self.converter.convert(["with_call_dps_packed"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs['with_call_dps_packed'](x) + expected = x + assert torch.allclose(result, expected) + + def test_complex_function(self): + """Test complex function with multiple operations.""" + converted_ir_mod = self.converter.convert(["complex_function"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs['complex_function'](x, y) + + # Expected: relu(add(relu(add(x, y)), y)) + step1 = torch.add(x, y) + step2 = torch.nn.functional.relu(step1) + step3 = torch.add(step2, y) # TIR call + expected = torch.nn.functional.relu(step3) + + assert torch.allclose(result, expected) + + def test_symbolic_shapes(self): + """Test symbolic shape handling.""" + converted_ir_mod = self.converter.convert([ + "symbolic_add", "symbolic_matmul", "symbolic_expand_dims" + ]) + + # Test symbolic_add + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32) + result = converted_ir_mod.pyfuncs['symbolic_add'](x, y) + expected = torch.add(x, y) + assert torch.allclose(result, expected) + + # Test symbolic_matmul + x = torch.randn(2, 3, 4, dtype=torch.float32) # (batch=2, m=3, k=4) + y = torch.randn(2, 4, 5, dtype=torch.float32) # (batch=2, k=4, n=5) + result = converted_ir_mod.pyfuncs['symbolic_matmul'](x, y) + expected = torch.matmul(x, y) + assert torch.allclose(result, expected) + assert result.shape == (2, 3, 5) + + # Test symbolic_expand_dims + x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + result = converted_ir_mod.pyfuncs['symbolic_expand_dims'](x) + expected = torch.unsqueeze(x, dim=2) + assert torch.allclose(result, expected) + assert result.shape == (2, 2, 1) + + def test_multi_operations(self): + """Test multiple operations in sequence.""" + converted_ir_mod = self.converter.convert(["multi_ops"]) + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], dtype=torch.float32) + y = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs['multi_ops'](x, y) + + # Expected: maximum(power(multiply(add(x, y), y), 2), x) + step1 = torch.add(x, y) + step2 = torch.mul(step1, y) + step3 = torch.pow(step2, 2.0) + expected = torch.maximum(step3, x) + + assert torch.allclose(result, expected) + + def test_reduction_operations(self): + """Test reduction operations.""" + converted_ir_mod = self.converter.convert(["reduction_ops"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs['reduction_ops'](x) + + # Expected: sum(x) + mean(x) + max(x) + expected = torch.sum(x) + torch.mean(x) + torch.max(x) + + assert torch.allclose(result, expected) + assert result.shape == () + + def test_comparison_operations(self): + """Test comparison operations.""" + converted_ir_mod = self.converter.convert(["comparison_ops"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([1.0, 2.5, 3.0, 4.5, 5.0], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs['comparison_ops'](x, y) + + # Expected: logical_and(equal(x, y), greater(x, y)) + eq_val = torch.eq(x, y) + gt_val = torch.gt(x, y) + expected = torch.logical_and(eq_val, gt_val) + + assert torch.allclose(result, expected) + assert result.dtype == torch.bool + + def test_operator_mapping_completeness(self): + """Test that operator mapping is comprehensive.""" + operator_map = RelaxToPyFuncConverter._get_relax_to_pytorch_operator_map() + + # Check that we have a good number of operators + assert len(operator_map) > 100, f"Expected >100 operators, got {len(operator_map)}" + + # Check key operator categories + binary_ops = [op for op in operator_map.keys() if op.startswith("relax.") and not op.startswith("relax.nn.")] + nn_ops = [op for op in operator_map.keys() if op.startswith("relax.nn.")] + + assert len(binary_ops) > 20, f"Expected >20 binary ops, got {len(binary_ops)}" + assert len(nn_ops) > 30, f"Expected >30 nn ops, got {len(nn_ops)}" + + # Check specific important operators + important_ops = [ + "relax.add", "relax.multiply", "relax.nn.relu", "relax.nn.softmax", + "relax.matmul", "relax.reshape", "relax.sum", "relax.mean" + ] + + for op in important_ops: + assert op in operator_map, f"Missing important operator: {op}" + + def test_error_handling(self): + """Test error handling for invalid inputs.""" + converted_ir_mod = self.converter.convert(["simple_add"]) + + # Test with wrong number of arguments + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + + with pytest.raises(ValueError, match="Expected 2 arguments"): + converted_ir_mod.pyfuncs['simple_add'](x) # Missing second argument + + # Test with incompatible shapes - this should raise a RuntimeError + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + y = torch.tensor([1.0, 2.0], dtype=torch.float32) # Different shape + + # This should raise a RuntimeError because shapes don't match + with pytest.raises(RuntimeError, match="The size of tensor a"): + converted_ir_mod.pyfuncs['simple_add'](x, y) + + def test_conversion_metadata(self): + """Test that conversion preserves metadata correctly.""" + converted_ir_mod = self.converter.convert(["simple_add"]) + + # Check that pyfuncs attribute exists + assert hasattr(converted_ir_mod, 'pyfuncs') + assert 'simple_add' in converted_ir_mod.pyfuncs + + # Check function metadata + pyfunc = converted_ir_mod.pyfuncs['simple_add'] + assert hasattr(pyfunc, '__name__') + assert hasattr(pyfunc, '__doc__') + assert pyfunc.__name__ == 'simple_add' + + def test_tensor_operations(self): + """Test tensor manipulation operations.""" + converted_ir_mod = self.converter.convert([ + "test_reshape", "test_permute_dims", "test_concat", "test_split", + "test_stack", "test_take", "test_flip", "test_tile", "test_repeat", + "test_expand_dims", "test_squeeze", "test_sum_with_axis", "test_max_with_axis" + ]) + + # Test reshape + x1 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result1 = converted_ir_mod.pyfuncs["test_reshape"](x1) + expected1 = torch.reshape(x1, (6,)) + assert torch.allclose(result1, expected1), "Reshape operation failed" + + # Test permute_dims + x2 = torch.randn(2, 3, 4) + result2 = converted_ir_mod.pyfuncs["test_permute_dims"](x2) + expected2 = torch.permute(x2, (2, 0, 1)) + assert torch.allclose(result2, expected2), "Permute_dims operation failed" + + # Test concat + x3 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + y3 = torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32) + result3 = converted_ir_mod.pyfuncs["test_concat"](x3, y3) + expected3 = torch.cat([x3, y3], dim=0) + assert torch.allclose(result3, expected3), "Concat operation failed" + + # Test split + x4 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32) + result4 = converted_ir_mod.pyfuncs["test_split"](x4) + expected4 = torch.split(x4, 2, dim=0) + assert len(result4) == len(expected4), "Split operation failed - wrong number of tensors" + for r, e in zip(result4, expected4): + assert torch.allclose(r, e), "Split operation failed - tensor mismatch" + + # Test stack + x5 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + y5 = torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32) + result5 = converted_ir_mod.pyfuncs["test_stack"](x5, y5) + expected5 = torch.stack([x5, y5], dim=1) + assert torch.allclose(result5, expected5), "Stack operation failed" + + # Test take + x6 = torch.tensor([[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0]], dtype=torch.float32) + indices = torch.tensor([0, 2], dtype=torch.int64) + result6 = converted_ir_mod.pyfuncs["test_take"](x6, indices) + expected6 = x6[indices] + assert torch.allclose(result6, expected6), "Take operation failed" + + # Test flip + x7 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result7 = converted_ir_mod.pyfuncs["test_flip"](x7) + expected7 = torch.flip(x7, dims=[1]) + assert torch.allclose(result7, expected7), "Flip operation failed" + + # Test tile + x8 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result8 = converted_ir_mod.pyfuncs["test_tile"](x8) + expected8 = torch.tile(x8, (2, 2)) + assert torch.allclose(result8, expected8), "Tile operation failed" + + # Test repeat + x9 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result9 = converted_ir_mod.pyfuncs["test_repeat"](x9) + expected9 = torch.repeat_interleave(x9, repeats=2, dim=0) + assert torch.allclose(result9, expected9), "Repeat operation failed" + + # Test expand_dims + x10 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result10 = converted_ir_mod.pyfuncs["test_expand_dims"](x10) + expected10 = torch.unsqueeze(x10, dim=2) + assert torch.allclose(result10, expected10), "Expand_dims operation failed" + + # Test squeeze + x11 = torch.tensor([[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]], dtype=torch.float32) + result11 = converted_ir_mod.pyfuncs["test_squeeze"](x11) + expected11 = torch.squeeze(x11, dim=2) + assert torch.allclose(result11, expected11), "Squeeze operation failed" + + # Test sum with axis + x12 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result12 = converted_ir_mod.pyfuncs["test_sum_with_axis"](x12) + expected12 = torch.sum(x12, dim=0) + assert torch.allclose(result12, expected12), "Sum with axis operation failed" + + # Test max with axis + x13 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result13 = converted_ir_mod.pyfuncs["test_max_with_axis"](x13) + expected13 = torch.max(x13, dim=0)[0] # torch.max returns (values, indices) + assert torch.allclose(result13, expected13), "Max with axis operation failed" + + +@I.ir_module +class ExtendedOperatorsModule: + """Extended test module with additional operators not covered in ComprehensiveTestModule.""" + + # Unary operations not covered in ComprehensiveTestModule + @R.function + def test_abs(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.abs(x) + + @R.function + def test_neg(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.negative(x) + + @R.function + def test_exp(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.exp(x) + + @R.function + def test_log(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.log(x) + + @R.function + def test_sqrt(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.sqrt(x) + + @R.function + def test_sin(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.sin(x) + + @R.function + def test_cos(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.cos(x) + + @R.function + def test_tanh(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.tanh(x) + + @R.function + def test_sigmoid(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.sigmoid(x) + + # Comparison operations not covered in ComprehensiveTestModule + @R.function + def test_less(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "bool"): + return R.less(x, y) + + @R.function + def test_not_equal(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "bool"): + return R.not_equal(x, y) + + # Binary operations not covered in ComprehensiveTestModule + @R.function + def test_multiply(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.multiply(x, y) + + @R.function + def test_divide(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.divide(x, y) + + @R.function + def test_power(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.power(x, y) + + @R.function + def test_maximum(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.maximum(x, y) + + @R.function + def test_minimum(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.minimum(x, y) + + @R.function + def test_subtract(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.subtract(x, y) + + # Additional tensor operations with different parameters + @R.function + def test_transpose_2d(x: R.Tensor((2, 4), "float32")) -> R.Tensor((4, 2), "float32"): + return R.permute_dims(x, axes=[1, 0]) + + @R.function + def test_mean_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"): + return R.mean(x, axis=0) + + @R.function + def test_min_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"): + return R.min(x, axis=0) + + # Neural network operations not covered in ComprehensiveTestModule + @R.function + def test_gelu_nn(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.nn.gelu(x) + + @R.function + def test_softmax_nn(x: R.Tensor((2, 5), "float32")) -> R.Tensor((2, 5), "float32"): + return R.nn.softmax(x, axis=1) + + @R.function + def test_log_softmax_nn(x: R.Tensor((2, 5), "float32")) -> R.Tensor((2, 5), "float32"): + return R.nn.log_softmax(x, axis=1) + + # Advanced tensor operations with different parameters + @R.function + def test_tile_dims(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 9), "float32"): + return R.tile(x, (2, 3)) + + @R.function + def test_repeat_axis(x: R.Tensor((3,), "float32")) -> R.Tensor((6,), "float32"): + return R.repeat(x, repeats=2, axis=0) + + +class TestExtendedOperators: + """Test class for extended operator coverage.""" + + @classmethod + def setup_class(cls): + """Set up test fixtures.""" + cls.ir_mod = ExtendedOperatorsModule + cls.converter = RelaxToPyFuncConverter(cls.ir_mod) + + def test_unary_operations(self): + """Test unary operations.""" + converted_ir_mod = self.converter.convert([ + "test_abs", "test_neg", "test_exp", "test_log", "test_sqrt" + ]) + + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32) + + # Test abs + result = converted_ir_mod.pyfuncs['test_abs'](x) + expected = torch.abs(x) + assert torch.allclose(result, expected) + + # Test negative + result = converted_ir_mod.pyfuncs['test_neg'](x) + expected = torch.neg(x) + assert torch.allclose(result, expected) + + # Test exp + result = converted_ir_mod.pyfuncs['test_exp'](x) + expected = torch.exp(x) + assert torch.allclose(result, expected) + + # Test log (with positive values) + x_pos = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + result = converted_ir_mod.pyfuncs['test_log'](x_pos) + expected = torch.log(x_pos) + assert torch.allclose(result, expected) + + # Test sqrt + result = converted_ir_mod.pyfuncs['test_sqrt'](x_pos) + expected = torch.sqrt(x_pos) + assert torch.allclose(result, expected) + + def test_trigonometric_operations(self): + """Test trigonometric operations.""" + converted_ir_mod = self.converter.convert([ + "test_sin", "test_cos", "test_tanh", "test_sigmoid" + ]) + + x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0], dtype=torch.float32) + + # Test sin + result = converted_ir_mod.pyfuncs['test_sin'](x) + expected = torch.sin(x) + assert torch.allclose(result, expected) + + # Test cos + result = converted_ir_mod.pyfuncs['test_cos'](x) + expected = torch.cos(x) + assert torch.allclose(result, expected) + + # Test tanh + result = converted_ir_mod.pyfuncs['test_tanh'](x) + expected = torch.tanh(x) + assert torch.allclose(result, expected) + + # Test sigmoid + result = converted_ir_mod.pyfuncs['test_sigmoid'](x) + expected = torch.sigmoid(x) + assert torch.allclose(result, expected) + + def test_comparison_operations(self): + """Test comparison operations not covered in ComprehensiveTestModule.""" + converted_ir_mod = self.converter.convert([ + "test_less", "test_not_equal" + ]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0], dtype=torch.float32) + + # Test less + result = converted_ir_mod.pyfuncs['test_less'](x, y) + expected = torch.lt(x, y) + assert torch.equal(result, expected) + + # Test not equal + result = converted_ir_mod.pyfuncs['test_not_equal'](x, y) + expected = torch.ne(x, y) + assert torch.equal(result, expected) + + def test_binary_operations(self): + """Test binary operations.""" + converted_ir_mod = self.converter.convert([ + "test_multiply", "test_divide", "test_power", "test_maximum", "test_minimum", "test_subtract" + ]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0], dtype=torch.float32) + + # Test multiply + result = converted_ir_mod.pyfuncs['test_multiply'](x, y) + expected = torch.mul(x, y) + assert torch.allclose(result, expected) + + # Test divide + result = converted_ir_mod.pyfuncs['test_divide'](x, y) + expected = torch.div(x, y) + assert torch.allclose(result, expected) + + # Test power + result = converted_ir_mod.pyfuncs['test_power'](x, y) + expected = torch.pow(x, y) + assert torch.allclose(result, expected) + + # Test maximum + result = converted_ir_mod.pyfuncs['test_maximum'](x, y) + expected = torch.maximum(x, y) + assert torch.allclose(result, expected) + + # Test minimum + result = converted_ir_mod.pyfuncs['test_minimum'](x, y) + expected = torch.minimum(x, y) + assert torch.allclose(result, expected) + + # Test subtract + result = converted_ir_mod.pyfuncs['test_subtract'](x, y) + expected = torch.sub(x, y) + assert torch.allclose(result, expected) + + def test_tensor_operations(self): + """Test tensor operations not covered in ComprehensiveTestModule.""" + converted_ir_mod = self.converter.convert([ + "test_transpose_2d" + ]) + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32) + + # Test transpose + result = converted_ir_mod.pyfuncs['test_transpose_2d'](x) + expected = torch.transpose(x, 0, 1) + assert torch.allclose(result, expected) + assert result.shape == (4, 2) + + def test_reduction_operations(self): + """Test reduction operations not covered in ComprehensiveTestModule.""" + converted_ir_mod = self.converter.convert([ + "test_mean_axis", "test_min_axis" + ]) + + x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + + # Test mean + result = converted_ir_mod.pyfuncs['test_mean_axis'](x) + expected = torch.mean(x, dim=0) + assert torch.allclose(result, expected) + assert result.shape == (3,) + + # Test min + result = converted_ir_mod.pyfuncs['test_min_axis'](x) + expected = torch.min(x, dim=0)[0] + assert torch.allclose(result, expected) + assert result.shape == (3,) + + def test_neural_network_operations(self): + """Test neural network operations not covered in ComprehensiveTestModule.""" + converted_ir_mod = self.converter.convert([ + "test_gelu_nn", "test_softmax_nn", "test_log_softmax_nn" + ]) + + x = torch.tensor([[-2.0, -1.0, 0.0, 1.0, 2.0], [0.5, 1.5, 2.5, 3.5, 4.5]], dtype=torch.float32) + + # Test gelu + result = converted_ir_mod.pyfuncs['test_gelu_nn'](x[0]) + expected = F.gelu(x[0]) + assert torch.allclose(result, expected) + + # Test softmax + result = converted_ir_mod.pyfuncs['test_softmax_nn'](x) + expected = F.softmax(x, dim=1) + assert torch.allclose(result, expected) + + # Test log_softmax + result = converted_ir_mod.pyfuncs['test_log_softmax_nn'](x) + expected = F.log_softmax(x, dim=1) + assert torch.allclose(result, expected) + + def test_advanced_tensor_operations(self): + """Test advanced tensor operations with different parameters.""" + converted_ir_mod = self.converter.convert([ + "test_tile_dims", "test_repeat_axis" + ]) + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32) + + # Test tile with different dimensions + result = converted_ir_mod.pyfuncs['test_tile_dims'](x) + expected = torch.tile(x, (2, 3)) + assert torch.allclose(result, expected) + assert result.shape == (4, 12) + + # Test repeat with different parameters + x_1d = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + result = converted_ir_mod.pyfuncs['test_repeat_axis'](x_1d) + expected = torch.repeat_interleave(x_1d, repeats=2, dim=0) + assert torch.allclose(result, expected) + assert result.shape == (6,) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 9f51679cbab311b5c888a6fb1898286da5d99dd5 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Thu, 4 Sep 2025 21:42:10 +0800 Subject: [PATCH 2/4] finish2 --- .../relax_to_pyfunc_conversion_refactored.py | 246 ------------------ 1 file changed, 246 deletions(-) delete mode 100644 examples/relax_to_pyfunc_conversion_refactored.py diff --git a/examples/relax_to_pyfunc_conversion_refactored.py b/examples/relax_to_pyfunc_conversion_refactored.py deleted file mode 100644 index 8d1091b9d130..000000000000 --- a/examples/relax_to_pyfunc_conversion_refactored.py +++ /dev/null @@ -1,246 +0,0 @@ -#!/usr/bin/env python3 -""" -Example: Converting Relax Functions to Python Functions (Refactored) - -This example demonstrates the new refactored architecture for converting Relax functions -to Python functions. The key improvement is that the converter now works with pure -IRModule objects, making it more modular and reusable. - -Key Features: -1. Pure IRModule conversion (no BasePyModule dependency) -2. Independent converter class -3. Convenience function for direct usage -4. BasePyModule integration for backward compatibility -""" - -import tvm -from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter, convert_relax_to_pyfunc -from tvm.relax.base_py_module import BasePyModule -from tvm.script import ir as I -from tvm.script import tir as T -from tvm.script import relax as R - - -@I.ir_module -class ExampleModule(BasePyModule): - """Example module with various Relax functions for conversion.""" - - @T.prim_func - def custom_add(var_x: T.handle, var_y: T.handle, var_out: T.handle): - """Custom TIR function for addition.""" - x = T.match_buffer(var_x, (5,), "float32") - y = T.match_buffer(var_y, (5,), "float32") - out = T.match_buffer(var_out, (5,), "float32") - - for i in range(5): - out[i] = x[i] + y[i] - - @R.function - def simple_math( - x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") - ) -> R.Tensor((5,), "float32"): - """Simple mathematical operations.""" - # Basic arithmetic - add_result = R.add(x, y) - multiply_result = R.multiply(add_result, R.const(2.0, "float32")) - return multiply_result - - @R.function - def neural_network_layer( - x: R.Tensor((10, 20), "float32"), - weight: R.Tensor((20, 10), "float32"), - bias: R.Tensor((10,), "float32") - ) -> R.Tensor((10, 10), "float32"): - """Neural network layer with linear transformation and activation.""" - # Linear transformation - linear_out = R.matmul(x, weight) - # Add bias - biased_out = R.add(linear_out, bias) - # Apply ReLU activation - activated_out = R.nn.relu(biased_out) - return activated_out - - @R.function - def with_tir_call( - x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") - ) -> R.Tensor((5,), "float32"): - """Function that calls a TIR function.""" - return R.call_tir(custom_add, (x, y), out_sinfo=R.Tensor((5,), "float32")) - - @R.function - def with_conditionals( - x: R.Tensor((5,), "float32"), threshold: R.Tensor((), "float32") - ) -> R.Tensor((5,), "float32"): - """Function with conditional logic.""" - # Create condition - condition = R.greater(x, threshold) - # True branch: multiply by 2 - true_result = R.multiply(x, R.const(2.0, "float32")) - # False branch: multiply by 0.5 - false_result = R.multiply(x, R.const(0.5, "float32")) - # Apply conditional - return R.where(condition, true_result, false_result) - - -def demo_pure_converter(): - """Demonstrate the pure converter approach.""" - print("=" * 80) - print("1. Pure Converter Approach (Recommended)") - print("=" * 80) - - ir_mod = ExampleModule - - # Create converter directly from IRModule - converter = RelaxToPyFuncConverter(ir_mod) - - print("\nConverting individual functions:") - print("-" * 50) - - # Convert single function - converted_ir_mod = converter.convert("simple_math") - simple_math_func = converted_ir_mod.pyfuncs["simple_math"] - result = simple_math_func("arg1", "arg2") - print(f"simple_math: {result}") - - # Convert multiple functions - converted_ir_mod = converter.convert(["neural_network_layer", "with_conditionals"]) - print(f"Converted functions: {list(converted_ir_mod.pyfuncs.keys())}") - - # Test converted functions - for func_name, func in converted_ir_mod.pyfuncs.items(): - if func_name == "neural_network_layer": - result = func("input", "weight", "bias") - elif func_name == "with_conditionals": - result = func("x", "threshold") - else: - result = func("arg1", "arg2") - print(f"{func_name}: {result}") - - -def demo_convenience_function(): - """Demonstrate the convenience function approach.""" - print("\n" + "=" * 80) - print("2. Convenience Function Approach") - print("=" * 80) - - ir_mod = ExampleModule - - print("\nUsing convenience function:") - print("-" * 50) - - # Convert using convenience function - converted_ir_mod = convert_relax_to_pyfunc(ir_mod, "simple_math") - simple_math_func = converted_ir_mod.pyfuncs["simple_math"] - result = simple_math_func("arg1", "arg2") - print(f"simple_math: {result}") - - # Convert multiple functions - converted_ir_mod = convert_relax_to_pyfunc(ir_mod, [ - "simple_math", - "neural_network_layer", - "with_conditionals" - ]) - print(f"Converted functions: {list(converted_ir_mod.pyfuncs.keys())}") - - -def demo_basepymodule_integration(): - """Demonstrate BasePyModule integration for backward compatibility.""" - print("\n" + "=" * 80) - print("3. BasePyModule Integration (Backward Compatibility)") - print("=" * 80) - - ir_mod = ExampleModule - device = tvm.cpu() - module = BasePyModule(ir_mod, device) - - print("\nUsing BasePyModule method:") - print("-" * 50) - - # Convert using BasePyModule method - converted_module = module.convert_relax_to_pyfunc("simple_math") - simple_math_func = converted_module.ir_mod.pyfuncs["simple_math"] - result = simple_math_func("arg1", "arg2") - print(f"simple_math: {result}") - - # Convert multiple functions - converted_module = module.convert_relax_to_pyfunc([ - "simple_math", - "neural_network_layer" - ]) - print(f"Converted functions: {list(converted_module.ir_mod.pyfuncs.keys())}") - - -def demo_operator_mapping(): - """Demonstrate the operator mapping functionality.""" - print("\n" + "=" * 80) - print("4. Operator Mapping") - print("=" * 80) - - ir_mod = ExampleModule - converter = RelaxToPyFuncConverter(ir_mod) - operator_map = converter.operator_map - - print(f"\nTotal operators mapped: {len(operator_map)}") - - # Show some key mappings - key_operators = [ - "relax.add", "relax.multiply", "relax.matmul", - "relax.nn.relu", "relax.nn.softmax", "relax.where" - ] - - print("\nKey operator mappings:") - for relax_op in key_operators: - if relax_op in operator_map: - pytorch_op = operator_map[relax_op] - print(f" {relax_op} -> {pytorch_op}") - - -def demo_architecture_benefits(): - """Demonstrate the benefits of the new architecture.""" - print("\n" + "=" * 80) - print("5. Architecture Benefits") - print("=" * 80) - - print("\nBenefits of the refactored architecture:") - print("-" * 50) - print("✓ Pure IRModule conversion (no BasePyModule dependency)") - print("✓ Independent converter class for reusability") - print("✓ Convenience function for direct usage") - print("✓ BasePyModule integration for backward compatibility") - print("✓ Better separation of concerns") - print("✓ Easier testing and maintenance") - print("✓ More modular design") - - print("\nUsage patterns:") - print("-" * 50) - print("# Pure converter (recommended)") - print("converter = RelaxToPyFuncConverter(ir_mod)") - print("converted_ir_mod = converter.convert(['func1', 'func2'])") - print() - print("# Convenience function") - print("converted_ir_mod = convert_relax_to_pyfunc(ir_mod, 'func1')") - print() - print("# BasePyModule integration") - print("module = BasePyModule(ir_mod, device)") - print("converted_module = module.convert_relax_to_pyfunc('func1')") - - -def main(): - """Main function demonstrating the refactored conversion process.""" - print("Relax to Python Function Conversion (Refactored Architecture)") - print("=" * 80) - - # Run all demos - demo_pure_converter() - demo_convenience_function() - demo_basepymodule_integration() - demo_operator_mapping() - demo_architecture_benefits() - - print("\n" + "=" * 80) - print("Refactored architecture demo completed successfully!") - print("=" * 80) - - -if __name__ == "__main__": - main() From 5576c336c10eb81e912d23d8052dba48f90cf77b Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Thu, 4 Sep 2025 23:05:50 +0800 Subject: [PATCH 3/4] opmit&lint --- python/tvm/relax/relax_to_pyfunc_converter.py | 408 +++++++++-------- .../relax/test_relax_to_pyfunc_converter.py | 429 ++++++++++-------- 2 files changed, 451 insertions(+), 386 deletions(-) diff --git a/python/tvm/relax/relax_to_pyfunc_converter.py b/python/tvm/relax/relax_to_pyfunc_converter.py index 2a26a35ee502..4306a6929fc3 100644 --- a/python/tvm/relax/relax_to_pyfunc_converter.py +++ b/python/tvm/relax/relax_to_pyfunc_converter.py @@ -31,30 +31,34 @@ class RelaxToPyFuncConverter: """Converter that works with IRModule to convert Relax functions to Python functions. - + This converter transforms Relax functions into Python functions that can be executed directly in Python/PyTorch environment. The conversion maps Relax operators to corresponding PyTorch APIs and handles special cases like call_tir and call_dps_packed. """ - + def __init__(self, ir_module: IRModule): """Initialize the converter with an IRModule. - + Args: ir_module: The IRModule containing Relax functions to convert """ self.ir_module = ir_module self.operator_map = self._get_relax_to_pytorch_operator_map() - + # Cache for RelaxExpressionConverter instances to avoid recreating them + self._converter_cache = {} + # Cache for operator mappings to avoid repeated lookups + self._op_cache = {} + def convert(self, relax_function_names: Union[str, List[str]]) -> IRModule: """Convert specified Relax functions to Python functions. - + Args: relax_function_names: Name(s) of Relax functions to convert - + Returns: Updated IRModule with converted Python functions stored in pyfuncs - + Example: >>> converter = RelaxToPyFuncConverter(ir_mod) >>> # Convert a single function @@ -64,66 +68,74 @@ def convert(self, relax_function_names: Union[str, List[str]]) -> IRModule: """ if isinstance(relax_function_names, str): relax_function_names = [relax_function_names] - + # Create a copy of the current IRModule new_ir_mod = self.ir_module.clone() - + # Initialize pyfuncs if not exists if not hasattr(new_ir_mod, "pyfuncs"): new_ir_mod.pyfuncs = {} - + # Get Relax function names from IRModule relax_func_names = [] for gv, func in self.ir_module.functions_items(): if isinstance(func, relax.Function): relax_func_names.append(gv.name_hint) - + # Convert each Relax function for func_name in relax_function_names: if func_name not in relax_func_names: raise ValueError(f"Relax function '{func_name}' not found in IRModule") - + # Get the Relax function relax_func = None for gv, func in self.ir_module.functions_items(): if gv.name_hint == func_name and isinstance(func, relax.Function): relax_func = func break - + if relax_func is None: raise ValueError(f"Could not find Relax function '{func_name}'") - + # Convert to Python function py_func = self._convert_relax_function_to_python(relax_func, func_name) - + # Store in pyfuncs new_ir_mod.pyfuncs[func_name] = py_func - + return new_ir_mod - - def _convert_relax_function_to_python(self, relax_func: relax.Function, func_name: str) -> callable: - """Convert a single Relax function to a Python function.""" + + def _convert_relax_function_to_python( + self, relax_func: relax.Function, func_name: str + ) -> callable: + """Convert a single Relax function to a Python function with caching.""" # Get function parameters params = relax_func.params - + # Create the Python function def converted_function(*args, **kwargs): """Converted Python function from Relax function.""" # Handle arguments if len(args) != len(params): raise ValueError(f"Expected {len(params)} arguments, got {len(args)}") - + + # Use cached converter or create new one + if func_name not in self._converter_cache: + self._converter_cache[func_name] = RelaxExpressionConverter( + self.operator_map, self.ir_module, self._op_cache + ) + # Execute the converted function body - converter = RelaxExpressionConverter(self.operator_map, self.ir_module) + converter = self._converter_cache[func_name] converter.current_params = params return converter.convert_expr(relax_func.body, args) - + # Set function metadata converted_function.__name__ = func_name converted_function.__doc__ = f"Converted Python function from Relax function: {func_name}" - + return converted_function - + @staticmethod def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: """Get the mapping from Relax operators to PyTorch operators.""" @@ -140,14 +152,12 @@ def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: "relax.mod": "torch.fmod", "relax.floor_mod": "torch.remainder", "relax.log_add_exp": "torch.logaddexp", - # Bitwise operations "relax.bitwise_and": "torch.bitwise_and", "relax.bitwise_or": "torch.bitwise_or", "relax.bitwise_xor": "torch.bitwise_xor", "relax.left_shift": "torch.left_shift", "relax.right_shift": "torch.right_shift", - # Unary operations "relax.abs": "torch.abs", "relax.negative": "torch.neg", @@ -167,7 +177,6 @@ def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: "relax.trunc": "torch.trunc", "relax.clip": "torch.clamp", "relax.bitwise_not": "torch.bitwise_not", - # Trigonometric functions "relax.acos": "torch.acos", "relax.asin": "torch.asin", @@ -178,13 +187,11 @@ def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: "relax.acosh": "torch.acosh", "relax.asinh": "torch.asinh", "relax.atanh": "torch.atanh", - # Special functions "relax.erf": "torch.erf", "relax.isfinite": "torch.isfinite", "relax.isinf": "torch.isinf", "relax.isnan": "torch.isnan", - # Neural network operations "relax.nn.relu": "F.relu", "relax.nn.relu6": "F.relu6", @@ -224,7 +231,6 @@ def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: "relax.nn.nll_loss": "F.nll_loss", "relax.nn.pad": "F.pad", "relax.nn.pixel_shuffle": "F.pixel_shuffle", - # Tensor operations "relax.matmul": "torch.matmul", "relax.linear": "F.linear", @@ -247,7 +253,6 @@ def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: "relax.meshgrid": "torch.meshgrid", "relax.one_hot": "F.one_hot", "relax.layout_transform": "torch.permute", # Approximate mapping - # Indexing operations "relax.take": "take", # Special handling needed "relax.gather_elements": "torch.gather", @@ -259,7 +264,6 @@ def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: "relax.strided_slice": "torch.slice", "relax.dynamic_strided_slice": "torch.slice", "relax.slice_scatter": "torch.scatter", - # Reduction operations "relax.sum": "sum", # Special handling needed "relax.mean": "mean", # Special handling needed @@ -272,7 +276,6 @@ def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: "relax.cumprod": "torch.cumprod", "relax.argmax": "torch.argmax", "relax.argmin": "torch.argmin", - # Comparison operations "relax.equal": "torch.eq", "relax.not_equal": "torch.ne", @@ -280,13 +283,11 @@ def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: "relax.greater_equal": "torch.ge", "relax.less": "torch.lt", "relax.less_equal": "torch.le", - # Logical operations "relax.logical_and": "torch.logical_and", "relax.logical_or": "torch.logical_or", "relax.logical_not": "torch.logical_not", "relax.logical_xor": "torch.logical_xor", - # Creation operations "relax.zeros": "torch.zeros", "relax.ones": "torch.ones", @@ -300,35 +301,27 @@ def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: "relax.tril": "torch.tril", "relax.triu": "torch.triu", "relax.hamming_window": "torch.hamming_window", - # Search operations "relax.where": "torch.where", "relax.bucketize": "torch.bucketize", "relax.nonzero": "torch.nonzero", "relax.unique": "torch.unique", - # Sorting operations "relax.sort": "torch.sort", "relax.argsort": "torch.argsort", "relax.topk": "torch.topk", - # Sampling operations "relax.multinomial_from_uniform": "torch.multinomial", - # Ternary operations "relax.ewise_fma": "torch.fma", # Approximate mapping - # Data type operations "relax.astype": "torch.to", "relax.wrap_param": "torch.tensor", - # Mask operations "relax.masked_fill": "torch.masked_fill", - # Quantization operations "relax.quantize": "torch.quantize_per_tensor", # Approximate mapping "relax.dequantize": "torch.dequantize", # Approximate mapping - # Special operations (handled separately) "relax.call_tir": "call_tir", "relax.call_tir_inplace": "call_tir_inplace", @@ -353,19 +346,27 @@ def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: class RelaxExpressionConverter: """Converter that transforms Relax expressions to Python/PyTorch code.""" - - def __init__(self, operator_map: Dict[str, str], ir_module: IRModule = None): + + def __init__( + self, + operator_map: Dict[str, str], + ir_module: IRModule = None, + op_cache: Dict[str, str] = None, + ): """Initialize the expression converter. - + Args: operator_map: Mapping from Relax operators to PyTorch operators ir_module: The IRModule containing TIR functions to compile + op_cache: Shared cache for operator mappings to avoid repeated lookups """ self.operator_map = operator_map self.variable_map: Dict[str, Any] = {} self.current_params: List[relax.Var] = [] self.ir_module = ir_module - + # Use shared operator cache or create new one + self._op_cache = op_cache if op_cache is not None else {} + def convert_expr(self, expr: relax.Expr, args: List[Any]) -> Any: """Convert a Relax expression to Python/PyTorch equivalent.""" if isinstance(expr, relax.Var): @@ -387,29 +388,29 @@ def convert_expr(self, expr: relax.Expr, args: List[Any]) -> Any: else: # Fallback for unknown expression types return f"" - + def _convert_var(self, var: relax.Var, args: List[Any]) -> Any: """Convert a Relax variable to Python equivalent.""" - if hasattr(var, 'name_hint'): + if hasattr(var, "name_hint"): var_name = var.name_hint - + # Check if it's a function parameter for i, param in enumerate(self.current_params): - if hasattr(param, 'name_hint') and param.name_hint == var_name: + if hasattr(param, "name_hint") and param.name_hint == var_name: return args[i] - + # Check if it's a bound variable if var_name in self.variable_map: return self.variable_map[var_name] - + # Return placeholder for unbound variables return f"" return f"" - + def _convert_call(self, call: relax.Call, args: List[Any]) -> Any: """Convert a Relax call to Python/PyTorch equivalent.""" op = call.op - + # Handle different types of calls if isinstance(op, relax.GlobalVar): # Function call @@ -422,12 +423,12 @@ def _convert_call(self, call: relax.Call, args: List[Any]) -> Any: return self._convert_extern_func_call(call, args) else: return f"" - + def _convert_function_call(self, call: relax.Call, args: List[Any]) -> Any: """Convert a Relax function call.""" func_name = call.op.name_hint call_args = [self.convert_expr(arg, args) for arg in call.args] - + # Handle special cases if func_name in ["call_tir", "call_tir_inplace"]: return self._convert_call_tir(call, args) @@ -436,14 +437,16 @@ def _convert_function_call(self, call: relax.Call, args: List[Any]) -> Any: else: # Regular function call return f"" - + def _convert_operator_call(self, call: relax.Call, args: List[Any]) -> Any: """Convert a Relax operator call to PyTorch equivalent.""" op_name = call.op.name call_args = [self.convert_expr(arg, args) for arg in call.args] - - # Get PyTorch equivalent - pytorch_op = self.operator_map.get(op_name) + + # Use cached operator mapping or look it up + if op_name not in self._op_cache: + self._op_cache[op_name] = self.operator_map.get(op_name) + pytorch_op = self._op_cache[op_name] if pytorch_op: try: # Handle special operations @@ -480,18 +483,18 @@ def _convert_operator_call(self, call: relax.Call, args: List[Any]) -> Any: # Neural network function func_name = pytorch_op[2:] # Remove "F." prefix func = getattr(F, func_name) - + # Special handling for functions that need dim parameter if func_name in ["softmax", "log_softmax"]: # Extract axis from call.attrs and convert to dim axis = None - if call.attrs and hasattr(call.attrs, 'axis'): + if call.attrs and hasattr(call.attrs, "axis"): axis = call.attrs.axis - if hasattr(axis, 'value'): + if hasattr(axis, "value"): axis = int(axis.value) elif isinstance(axis, (int, float)): axis = int(axis) - + if axis is not None: return func(call_args[0], dim=axis) else: @@ -517,26 +520,26 @@ def _convert_operator_call(self, call: relax.Call, args: List[Any]) -> Any: else: # Unknown operator return f"" - + def _convert_extern_func_call(self, call: relax.Call, args: List[Any]) -> Any: """Convert an external function call.""" func_name = call.op.global_symbol call_args = [self.convert_expr(arg, args) for arg in call.args] - + if func_name in ["call_tir", "call_tir_inplace"]: return self._convert_call_tir(call, args) elif func_name in ["call_dps_packed", "call_pure_packed"]: return self._convert_call_dps_packed(call, args) else: return f"" - + def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any: """Convert call_tir to Python equivalent with DLPack conversion.""" # Extract TIR function name and arguments tir_func = call.args[0] tir_args = call.args[1] if len(call.args) > 1 else [] out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None - + # Get function name if isinstance(tir_func, relax.GlobalVar): func_name = tir_func.name_hint @@ -546,17 +549,17 @@ def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any: if isinstance(func_name, str) and func_name.startswith("<"): # If it's a placeholder, extract the name func_name = str(tir_func) - + # Convert arguments to PyTorch tensors converted_args = [self.convert_expr(arg, args) for arg in tir_args] - + try: # First, try to get the TIR function from the current IRModule tir_function = None if self.ir_module: # Look for the TIR function in the current IRModule for gv, func in self.ir_module.functions.items(): - if gv.name_hint == func_name and hasattr(func, 'body'): + if gv.name_hint == func_name and hasattr(func, "body"): try: # Compile the TIR function target = tvm.target.Target("llvm") @@ -564,16 +567,20 @@ def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any: tir_function = tvm.compile(func, target=target) break except Exception as compile_e: - print(f"Warning: Failed to compile TIR function {func_name}: {compile_e}") + print( + f"Warning: Failed to compile TIR function {func_name}: {compile_e}" + ) continue - + # If not found in current module, try global registry if tir_function is None: tir_function = tvm.get_global_func(func_name) - + if tir_function is None: - return f"" - + return ( + f"" + ) + # Convert PyTorch tensors to TVM NDArrays via DLPack tvm_args = [] for arg in converted_args: @@ -583,41 +590,41 @@ def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any: tvm_args.append(tvm_arg) else: tvm_args.append(arg) - + # For call_tir, we need to allocate output tensor output_shape = None - if out_sinfo and hasattr(out_sinfo, 'shape'): + if out_sinfo and hasattr(out_sinfo, "shape"): output_shape = out_sinfo.shape elif converted_args: # Use the shape of the first input tensor first_arg = converted_args[0] if isinstance(first_arg, torch.Tensor): output_shape = first_arg.shape - + if output_shape is None: return f"" - + # Allocate output tensor output_tensor = tvm.nd.array(tvm.nd.empty(output_shape, dtype="float32")) tvm_args.append(output_tensor) - + # Call the TIR function tir_function(*tvm_args) - + # The result is in the output_tensor we allocated # Convert result back to PyTorch tensor via DLPack return torch.from_dlpack(output_tensor.to_dlpack()) - + except Exception as e: return f"" - + def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any: """Convert call_dps_packed to Python equivalent with DLPack conversion.""" # Extract packed function name and arguments packed_func = call.args[0] packed_args = call.args[1] if len(call.args) > 1 else [] out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None - + # Get function name if isinstance(packed_func, relax.GlobalVar): func_name = packed_func.name_hint @@ -625,16 +632,16 @@ def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any: func_name = packed_func.global_symbol else: func_name = str(packed_func) - + # Convert arguments to PyTorch tensors converted_args = [self.convert_expr(arg, args) for arg in packed_args] - + try: # Get the packed function from TVM packed_function = tvm.get_global_func(func_name) if packed_function is None: return f"" - + # Convert PyTorch tensors to TVM NDArrays via DLPack tvm_args = [] for arg in converted_args: @@ -644,131 +651,133 @@ def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any: tvm_args.append(tvm_arg) else: tvm_args.append(arg) - + # Call the packed function result = packed_function(*tvm_args) - + # Convert result back to PyTorch tensor via DLPack if isinstance(result, tvm.nd.NDArray): return torch.from_dlpack(result.to_dlpack()) else: return result - + except Exception as e: return f"" - + def _convert_constant(self, const: relax.Constant) -> Any: """Convert a Relax constant to Python equivalent.""" - if hasattr(const, 'data'): + if hasattr(const, "data"): data = const.data # Convert TVM NDArray to Python scalar if it's a scalar - if hasattr(data, 'numpy'): + if hasattr(data, "numpy"): numpy_data = data.numpy() if numpy_data.size == 1: return float(numpy_data.item()) else: # For multi-element arrays, convert to PyTorch tensor return torch.from_numpy(numpy_data) - elif hasattr(data, 'item'): + elif hasattr(data, "item"): # Single element tensor return data.item() else: return data return f"" - + def _convert_seq_expr(self, seq: relax.SeqExpr, args: List[Any]) -> Any: """Convert a Relax sequence expression.""" # Convert blocks for block in seq.blocks: - if hasattr(block, 'bindings'): + if hasattr(block, "bindings"): for binding in block.bindings: if isinstance(binding, relax.VarBinding): var_name = binding.var.name_hint value = self.convert_expr(binding.value, args) self.variable_map[var_name] = value - + # Convert body return self.convert_expr(seq.body, args) - + def _convert_tuple(self, tuple_expr: relax.Tuple, args: List[Any]) -> Any: """Convert a Relax tuple to Python tuple.""" elements = [self.convert_expr(elem, args) for elem in tuple_expr.fields] return tuple(elements) - + def _convert_tuple_get_item(self, get_item: relax.TupleGetItem, args: List[Any]) -> Any: """Convert a Relax tuple get item to Python equivalent.""" tuple_expr = self.convert_expr(get_item.tuple_value, args) index = get_item.index return f"" - + def _convert_if(self, if_expr: relax.If, args: List[Any]) -> Any: """Convert a Relax if expression to Python equivalent.""" condition = self.convert_expr(if_expr.cond, args) true_branch = self.convert_expr(if_expr.true_branch, args) false_branch = self.convert_expr(if_expr.false_branch, args) return f"" - + def _convert_expand_dims(self, call: relax.Call, args: List[Any]) -> Any: """Convert expand_dims to torch.unsqueeze with proper axis handling.""" if len(call.args) < 1: return f"" - + # Convert the tensor argument tensor_arg = self.convert_expr(call.args[0], args) - + # Get the axis from call.attrs axis = None - if call.attrs and hasattr(call.attrs, 'axis'): + if call.attrs and hasattr(call.attrs, "axis"): axis = call.attrs.axis # Handle different types of axis - if hasattr(axis, '__iter__') and not isinstance(axis, str): + if hasattr(axis, "__iter__") and not isinstance(axis, str): # It's an array/list, take the first element axis = list(axis)[0] if len(axis) > 0 else None - + # Handle TVM types - if hasattr(axis, 'value'): + if hasattr(axis, "value"): # It's a TVM IntImm or similar, get the value axis = int(axis.value) elif isinstance(axis, (int, float)): axis = int(axis) - + if axis is None: return f"" - + # Use torch.unsqueeze with the correct axis return torch.unsqueeze(tensor_arg, dim=axis) - + def _convert_reduction_op(self, call: relax.Call, args: List[Any], op_name: str) -> Any: """Convert reduction operations with axis and keepdims parameters.""" if len(call.args) < 1: return f"<{op_name}_error: insufficient arguments>" - + # Convert the tensor argument tensor_arg = self.convert_expr(call.args[0], args) - + # Get axis and keepdims from call.attrs axis = None keepdims = False - + if call.attrs: - if hasattr(call.attrs, 'axis') and call.attrs.axis is not None: + if hasattr(call.attrs, "axis") and call.attrs.axis is not None: axis = call.attrs.axis # Handle different types of axis - if hasattr(axis, '__iter__') and not isinstance(axis, str): + if hasattr(axis, "__iter__") and not isinstance(axis, str): # It's an array/list, convert to list of ints - axis = [int(item.value) if hasattr(item, 'value') else int(item) for item in axis] - elif hasattr(axis, 'value'): + axis = [ + int(item.value) if hasattr(item, "value") else int(item) for item in axis + ] + elif hasattr(axis, "value"): # It's a TVM IntImm, get the value axis = int(axis.value) elif isinstance(axis, (int, float)): axis = int(axis) - - if hasattr(call.attrs, 'keepdims'): + + if hasattr(call.attrs, "keepdims"): keepdims = bool(call.attrs.keepdims) - + # Get the PyTorch function func = getattr(torch, op_name) - + # Call with appropriate parameters if axis is not None: # For max and min, PyTorch returns (values, indices) tuple when dim is specified @@ -786,41 +795,41 @@ def _convert_reduction_op(self, call: relax.Call, args: List[Any], op_name: str) return func(tensor_arg, dim=axis, keepdim=keepdims) else: return func(tensor_arg) - + def _convert_squeeze(self, call: relax.Call, args: List[Any]) -> Any: """Convert squeeze to torch.squeeze with proper axis handling.""" if len(call.args) < 1: return f"" - + # Convert the tensor argument tensor_arg = self.convert_expr(call.args[0], args) - + # Get axis from call.attrs axis = None - if call.attrs and hasattr(call.attrs, 'axis') and call.attrs.axis is not None: + if call.attrs and hasattr(call.attrs, "axis") and call.attrs.axis is not None: axis = call.attrs.axis # Handle different types of axis - if hasattr(axis, '__iter__') and not isinstance(axis, str): - axis = [int(item.value) if hasattr(item, 'value') else int(item) for item in axis] - elif hasattr(axis, 'value'): + if hasattr(axis, "__iter__") and not isinstance(axis, str): + axis = [int(item.value) if hasattr(item, "value") else int(item) for item in axis] + elif hasattr(axis, "value"): axis = int(axis.value) elif isinstance(axis, (int, float)): axis = int(axis) - + # Call torch.squeeze with appropriate parameters if axis is not None: return torch.squeeze(tensor_arg, dim=axis) else: return torch.squeeze(tensor_arg) - + def _convert_tensor_ops(self, call: relax.Call, args: List[Any], op_name: str) -> Any: """Convert tensor operations like concat, split, stack.""" if len(call.args) < 1: return f"<{op_name}_error: insufficient arguments>" - + # Convert arguments converted_args = [self.convert_expr(arg, args) for arg in call.args] - + if op_name == "concat": # torch.cat(tensors, dim=0) # In Relax, concat takes a tuple of tensors as first argument @@ -831,34 +840,34 @@ def _convert_tensor_ops(self, call: relax.Call, args: List[Any], op_name: str) - # Direct tensor arguments tensors = converted_args axis = 0 - if call.attrs and hasattr(call.attrs, 'axis'): + if call.attrs and hasattr(call.attrs, "axis"): axis = call.attrs.axis - if hasattr(axis, 'value'): + if hasattr(axis, "value"): axis = int(axis.value) elif isinstance(axis, (int, float)): axis = int(axis) return torch.cat(tensors, dim=axis) - + elif op_name == "split": # torch.split(tensor, split_size_or_sections, dim=0) tensor = converted_args[0] split_size = converted_args[1] if len(converted_args) > 1 else 1 axis = 0 - if call.attrs and hasattr(call.attrs, 'axis'): + if call.attrs and hasattr(call.attrs, "axis"): axis = call.attrs.axis - if hasattr(axis, 'value'): + if hasattr(axis, "value"): axis = int(axis.value) elif isinstance(axis, (int, float)): axis = int(axis) - + # Handle indices_or_sections parameter - if call.attrs and hasattr(call.attrs, 'indices_or_sections'): + if call.attrs and hasattr(call.attrs, "indices_or_sections"): indices_or_sections = call.attrs.indices_or_sections - if hasattr(indices_or_sections, 'value'): + if hasattr(indices_or_sections, "value"): indices_or_sections = int(indices_or_sections.value) elif isinstance(indices_or_sections, (int, float)): indices_or_sections = int(indices_or_sections) - + # If indices_or_sections is an integer, it means split into N equal parts if isinstance(indices_or_sections, int): total_size = tensor.shape[axis] @@ -869,7 +878,7 @@ def _convert_tensor_ops(self, call: relax.Call, args: List[Any], op_name: str) - return torch.split(tensor, indices_or_sections, dim=axis) else: return torch.split(tensor, split_size, dim=axis) - + elif op_name == "stack": # torch.stack(tensors, dim=0) if len(converted_args) == 1 and isinstance(converted_args[0], tuple): @@ -877,29 +886,31 @@ def _convert_tensor_ops(self, call: relax.Call, args: List[Any], op_name: str) - else: tensors = converted_args axis = 0 - if call.attrs and hasattr(call.attrs, 'axis'): + if call.attrs and hasattr(call.attrs, "axis"): axis = call.attrs.axis - if hasattr(axis, 'value'): + if hasattr(axis, "value"): axis = int(axis.value) elif isinstance(axis, (int, float)): axis = int(axis) return torch.stack(tensors, dim=axis) - + else: return f"<{op_name}_error: unsupported operation>" - + def _convert_reshape(self, call: relax.Call, args: List[Any]) -> Any: """Convert reshape operation.""" if len(call.args) < 2: return "" - + tensor_arg = self.convert_expr(call.args[0], args) shape_arg = call.args[1] - + # Convert shape argument to Python tuple if isinstance(shape_arg, relax.ShapeExpr): - if hasattr(shape_arg, 'values'): - shape = tuple(int(v.value) if hasattr(v, 'value') else int(v) for v in shape_arg.values) + if hasattr(shape_arg, "values"): + shape = tuple( + int(v.value) if hasattr(v, "value") else int(v) for v in shape_arg.values + ) else: shape = (int(shape_arg),) elif isinstance(shape_arg, relax.Constant): @@ -913,50 +924,49 @@ def _convert_reshape(self, call: relax.Call, args: List[Any]) -> Any: shape = tuple(int(v) for v in converted_shape) else: shape = (int(converted_shape),) - + return torch.reshape(tensor_arg, shape) - - + def _convert_permute_dims(self, call: relax.Call, args: List[Any]) -> Any: """Convert permute_dims operation.""" if len(call.args) < 1: return "" - + tensor_arg = self.convert_expr(call.args[0], args) - + # Extract axes from call.attrs - if call.attrs and hasattr(call.attrs, 'axes'): + if call.attrs and hasattr(call.attrs, "axes"): axes = call.attrs.axes # Handle TVM Array type - if hasattr(axes, '__iter__') and not isinstance(axes, str): + if hasattr(axes, "__iter__") and not isinstance(axes, str): # Convert TVM Array or Python list/tuple to tuple of ints - axes = tuple(int(v.value) if hasattr(v, 'value') else int(v) for v in axes) + axes = tuple(int(v.value) if hasattr(v, "value") else int(v) for v in axes) elif isinstance(axes, (list, tuple)): axes = tuple(int(v) for v in axes) else: axes = (int(axes),) else: return "" - + return torch.permute(tensor_arg, axes) - + def _convert_take(self, call: relax.Call, args: List[Any]) -> Any: """Convert take operation.""" if len(call.args) < 2: return "" - + tensor_arg = self.convert_expr(call.args[0], args) indices_arg = self.convert_expr(call.args[1], args) - + # Extract axis from call.attrs axis = None - if call.attrs and hasattr(call.attrs, 'axis'): + if call.attrs and hasattr(call.attrs, "axis"): axis = call.attrs.axis - if hasattr(axis, 'value'): + if hasattr(axis, "value"): axis = int(axis.value) elif isinstance(axis, (int, float)): axis = int(axis) - + if axis is not None: # Use advanced indexing for specific axis if axis == 0: @@ -967,101 +977,103 @@ def _convert_take(self, call: relax.Call, args: List[Any]) -> Any: else: # No axis specified, use torch.take (flattens the tensor) return torch.take(tensor_arg, indices_arg) - + def _convert_flip(self, call: relax.Call, args: List[Any]) -> Any: """Convert flip operation.""" if len(call.args) < 1: return "" - + tensor_arg = self.convert_expr(call.args[0], args) - + # Extract axis from call.attrs axis = None - if call.attrs and hasattr(call.attrs, 'axis'): + if call.attrs and hasattr(call.attrs, "axis"): axis = call.attrs.axis - if hasattr(axis, 'value'): + if hasattr(axis, "value"): axis = int(axis.value) elif isinstance(axis, (int, float)): axis = int(axis) - + if axis is not None: # Convert single axis to list for torch.flip dims = [axis] else: # Default: flip all dimensions dims = list(range(tensor_arg.dim())) - + return torch.flip(tensor_arg, dims=dims) - + def _convert_tile(self, call: relax.Call, args: List[Any]) -> Any: """Convert tile operation.""" if len(call.args) < 1: return "" - + tensor_arg = self.convert_expr(call.args[0], args) - + # Extract repeats from call.attrs - if call.attrs and hasattr(call.attrs, 'repeats'): + if call.attrs and hasattr(call.attrs, "repeats"): repeats = call.attrs.repeats # Handle TVM Array type - if hasattr(repeats, '__iter__') and not isinstance(repeats, str): - repeats = tuple(int(v.value) if hasattr(v, 'value') else int(v) for v in repeats) + if hasattr(repeats, "__iter__") and not isinstance(repeats, str): + repeats = tuple(int(v.value) if hasattr(v, "value") else int(v) for v in repeats) elif isinstance(repeats, (list, tuple)): repeats = tuple(int(v) for v in repeats) else: repeats = (int(repeats),) else: return "" - + return torch.tile(tensor_arg, dims=repeats) - + def _convert_repeat(self, call: relax.Call, args: List[Any]) -> Any: """Convert repeat operation.""" if len(call.args) < 1: return "" - + tensor_arg = self.convert_expr(call.args[0], args) - + # Extract repeats and axis from call.attrs repeats = 1 axis = None - - if call.attrs and hasattr(call.attrs, 'repeats'): + + if call.attrs and hasattr(call.attrs, "repeats"): repeats = call.attrs.repeats - if hasattr(repeats, 'value'): + if hasattr(repeats, "value"): repeats = int(repeats.value) elif isinstance(repeats, (int, float)): repeats = int(repeats) - - if call.attrs and hasattr(call.attrs, 'axis'): + + if call.attrs and hasattr(call.attrs, "axis"): axis = call.attrs.axis - if hasattr(axis, 'value'): + if hasattr(axis, "value"): axis = int(axis.value) elif isinstance(axis, (int, float)): axis = int(axis) - + if axis is not None: return torch.repeat_interleave(tensor_arg, repeats=repeats, dim=axis) else: return torch.repeat_interleave(tensor_arg, repeats=repeats) - + def _convert_shape_expr(self, shape_expr: relax.ShapeExpr) -> Any: """Convert a Relax shape expression to Python equivalent.""" - if hasattr(shape_expr, 'values'): + if hasattr(shape_expr, "values"): return f"" return f"" -def convert_relax_to_pyfunc(ir_module: IRModule, relax_function_names: Union[str, List[str]]) -> IRModule: +def convert_relax_to_pyfunc( + ir_module: IRModule, relax_function_names: Union[str, List[str]] +) -> IRModule: """Convert Relax functions to Python functions. - + Args: ir_module: The IRModule containing Relax functions relax_function_names: Name(s) of Relax functions to convert - + Returns: IRModule with converted Python functions stored in pyfuncs - + Example: >>> converted_ir_mod = convert_relax_to_pyfunc(ir_mod, "my_function") >>> converted_ir_mod = convert_relax_to_pyfunc(ir_mod, ["func1", "func2"]) diff --git a/tests/python/relax/test_relax_to_pyfunc_converter.py b/tests/python/relax/test_relax_to_pyfunc_converter.py index 6dc72adafdb6..8133be38a0b3 100644 --- a/tests/python/relax/test_relax_to_pyfunc_converter.py +++ b/tests/python/relax/test_relax_to_pyfunc_converter.py @@ -75,10 +75,10 @@ def with_call_tir( return R.call_tir(cls.add_tir, (x, y), out_sinfo=R.Tensor((5,), "float32")) @R.function - def with_call_dps_packed( - x: R.Tensor((5,), "float32") - ) -> R.Tensor((5,), "float32"): - return R.call_dps_packed("my_softmax", (x, R.prim_value(1)), out_sinfo=R.Tensor((5,), "float32")) + def with_call_dps_packed(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.call_dps_packed( + "my_softmax", (x, R.prim_value(1)), out_sinfo=R.Tensor((5,), "float32") + ) @R.function def complex_function( @@ -98,8 +98,7 @@ def symbolic_add( @R.function def symbolic_matmul( - x: R.Tensor(("batch", "m", "k"), "float32"), - y: R.Tensor(("batch", "k", "n"), "float32") + x: R.Tensor(("batch", "m", "k"), "float32"), y: R.Tensor(("batch", "k", "n"), "float32") ) -> R.Tensor(("batch", "m", "n"), "float32"): return R.matmul(x, y) @@ -143,7 +142,9 @@ def test_permute_dims(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((4, 2, 3), return R.permute_dims(x, axes=[2, 0, 1]) @R.function - def test_concat(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 3), "float32"): + def test_concat( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((4, 3), "float32"): return R.concat((x, y), axis=0) @R.function @@ -151,11 +152,15 @@ def test_split(x: R.Tensor((4, 3), "float32")) -> R.Tuple: return R.split(x, indices_or_sections=2, axis=0) @R.function - def test_stack(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 2, 3), "float32"): + def test_stack( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 2, 3), "float32"): return R.stack((x, y), axis=1) @R.function - def test_take(x: R.Tensor((3, 4), "float32"), indices: R.Tensor((2,), "int64")) -> R.Tensor((2,), "float32"): + def test_take( + x: R.Tensor((3, 4), "float32"), indices: R.Tensor((2,), "int64") + ) -> R.Tensor((2,), "float32"): return R.take(x, indices, axis=0) @R.function @@ -189,10 +194,11 @@ def test_max_with_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float3 def create_mock_packed_function(): """Create a mock packed function for testing.""" + def mock_softmax(x, axis): """Mock softmax function that just returns the input.""" return x - + # Register the function globally tvm.register_func("my_softmax", mock_softmax) @@ -210,29 +216,29 @@ def setup_class(cls): def test_basic_operations(self): """Test basic arithmetic operations.""" converted_ir_mod = self.converter.convert(["simple_add", "with_relu"]) - + # Test simple_add x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) - - result = converted_ir_mod.pyfuncs['simple_add'](x, y) + + result = converted_ir_mod.pyfuncs["simple_add"](x, y) expected = torch.add(x, y) assert torch.allclose(result, expected) - + # Test with_relu x_neg = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32) - result = converted_ir_mod.pyfuncs['with_relu'](x_neg) + result = converted_ir_mod.pyfuncs["with_relu"](x_neg) expected = torch.nn.functional.relu(x_neg) assert torch.allclose(result, expected) def test_call_tir(self): """Test call_tir functionality with DLPack conversion.""" converted_ir_mod = self.converter.convert(["with_call_tir"]) - + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) - - result = converted_ir_mod.pyfuncs['with_call_tir'](x, y) + + result = converted_ir_mod.pyfuncs["with_call_tir"](x, y) expected = torch.add(x, y) assert torch.allclose(result, expected) assert result.shape == expected.shape @@ -240,54 +246,54 @@ def test_call_tir(self): def test_call_dps_packed(self): """Test call_dps_packed functionality.""" converted_ir_mod = self.converter.convert(["with_call_dps_packed"]) - + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) - - result = converted_ir_mod.pyfuncs['with_call_dps_packed'](x) + + result = converted_ir_mod.pyfuncs["with_call_dps_packed"](x) expected = x assert torch.allclose(result, expected) def test_complex_function(self): """Test complex function with multiple operations.""" converted_ir_mod = self.converter.convert(["complex_function"]) - + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) - - result = converted_ir_mod.pyfuncs['complex_function'](x, y) - + + result = converted_ir_mod.pyfuncs["complex_function"](x, y) + # Expected: relu(add(relu(add(x, y)), y)) step1 = torch.add(x, y) step2 = torch.nn.functional.relu(step1) step3 = torch.add(step2, y) # TIR call expected = torch.nn.functional.relu(step3) - + assert torch.allclose(result, expected) def test_symbolic_shapes(self): """Test symbolic shape handling.""" - converted_ir_mod = self.converter.convert([ - "symbolic_add", "symbolic_matmul", "symbolic_expand_dims" - ]) - + converted_ir_mod = self.converter.convert( + ["symbolic_add", "symbolic_matmul", "symbolic_expand_dims"] + ) + # Test symbolic_add x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) y = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32) - result = converted_ir_mod.pyfuncs['symbolic_add'](x, y) + result = converted_ir_mod.pyfuncs["symbolic_add"](x, y) expected = torch.add(x, y) assert torch.allclose(result, expected) - + # Test symbolic_matmul x = torch.randn(2, 3, 4, dtype=torch.float32) # (batch=2, m=3, k=4) y = torch.randn(2, 4, 5, dtype=torch.float32) # (batch=2, k=4, n=5) - result = converted_ir_mod.pyfuncs['symbolic_matmul'](x, y) + result = converted_ir_mod.pyfuncs["symbolic_matmul"](x, y) expected = torch.matmul(x, y) assert torch.allclose(result, expected) assert result.shape == (2, 3, 5) - + # Test symbolic_expand_dims x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) - result = converted_ir_mod.pyfuncs['symbolic_expand_dims'](x) + result = converted_ir_mod.pyfuncs["symbolic_expand_dims"](x) expected = torch.unsqueeze(x, dim=2) assert torch.allclose(result, expected) assert result.shape == (2, 2, 1) @@ -295,194 +301,224 @@ def test_symbolic_shapes(self): def test_multi_operations(self): """Test multiple operations in sequence.""" converted_ir_mod = self.converter.convert(["multi_ops"]) - - x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], dtype=torch.float32) - y = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]], dtype=torch.float32) - - result = converted_ir_mod.pyfuncs['multi_ops'](x, y) - + + x = torch.tensor( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], + dtype=torch.float32, + ) + y = torch.tensor( + [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]], dtype=torch.float32 + ) + + result = converted_ir_mod.pyfuncs["multi_ops"](x, y) + # Expected: maximum(power(multiply(add(x, y), y), 2), x) step1 = torch.add(x, y) step2 = torch.mul(step1, y) step3 = torch.pow(step2, 2.0) expected = torch.maximum(step3, x) - + assert torch.allclose(result, expected) def test_reduction_operations(self): """Test reduction operations.""" converted_ir_mod = self.converter.convert(["reduction_ops"]) - + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) - - result = converted_ir_mod.pyfuncs['reduction_ops'](x) - + + result = converted_ir_mod.pyfuncs["reduction_ops"](x) + # Expected: sum(x) + mean(x) + max(x) expected = torch.sum(x) + torch.mean(x) + torch.max(x) - + assert torch.allclose(result, expected) assert result.shape == () def test_comparison_operations(self): """Test comparison operations.""" converted_ir_mod = self.converter.convert(["comparison_ops"]) - + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) y = torch.tensor([1.0, 2.5, 3.0, 4.5, 5.0], dtype=torch.float32) - - result = converted_ir_mod.pyfuncs['comparison_ops'](x, y) - + + result = converted_ir_mod.pyfuncs["comparison_ops"](x, y) + # Expected: logical_and(equal(x, y), greater(x, y)) eq_val = torch.eq(x, y) gt_val = torch.gt(x, y) expected = torch.logical_and(eq_val, gt_val) - + assert torch.allclose(result, expected) assert result.dtype == torch.bool def test_operator_mapping_completeness(self): """Test that operator mapping is comprehensive.""" operator_map = RelaxToPyFuncConverter._get_relax_to_pytorch_operator_map() - + # Check that we have a good number of operators assert len(operator_map) > 100, f"Expected >100 operators, got {len(operator_map)}" - + # Check key operator categories - binary_ops = [op for op in operator_map.keys() if op.startswith("relax.") and not op.startswith("relax.nn.")] + binary_ops = [ + op + for op in operator_map.keys() + if op.startswith("relax.") and not op.startswith("relax.nn.") + ] nn_ops = [op for op in operator_map.keys() if op.startswith("relax.nn.")] - + assert len(binary_ops) > 20, f"Expected >20 binary ops, got {len(binary_ops)}" assert len(nn_ops) > 30, f"Expected >30 nn ops, got {len(nn_ops)}" - + # Check specific important operators important_ops = [ - "relax.add", "relax.multiply", "relax.nn.relu", "relax.nn.softmax", - "relax.matmul", "relax.reshape", "relax.sum", "relax.mean" + "relax.add", + "relax.multiply", + "relax.nn.relu", + "relax.nn.softmax", + "relax.matmul", + "relax.reshape", + "relax.sum", + "relax.mean", ] - + for op in important_ops: assert op in operator_map, f"Missing important operator: {op}" def test_error_handling(self): """Test error handling for invalid inputs.""" converted_ir_mod = self.converter.convert(["simple_add"]) - + # Test with wrong number of arguments x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) - + with pytest.raises(ValueError, match="Expected 2 arguments"): - converted_ir_mod.pyfuncs['simple_add'](x) # Missing second argument - + converted_ir_mod.pyfuncs["simple_add"](x) # Missing second argument + # Test with incompatible shapes - this should raise a RuntimeError x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) y = torch.tensor([1.0, 2.0], dtype=torch.float32) # Different shape - + # This should raise a RuntimeError because shapes don't match with pytest.raises(RuntimeError, match="The size of tensor a"): - converted_ir_mod.pyfuncs['simple_add'](x, y) + converted_ir_mod.pyfuncs["simple_add"](x, y) def test_conversion_metadata(self): """Test that conversion preserves metadata correctly.""" converted_ir_mod = self.converter.convert(["simple_add"]) - + # Check that pyfuncs attribute exists - assert hasattr(converted_ir_mod, 'pyfuncs') - assert 'simple_add' in converted_ir_mod.pyfuncs - + assert hasattr(converted_ir_mod, "pyfuncs") + assert "simple_add" in converted_ir_mod.pyfuncs + # Check function metadata - pyfunc = converted_ir_mod.pyfuncs['simple_add'] - assert hasattr(pyfunc, '__name__') - assert hasattr(pyfunc, '__doc__') - assert pyfunc.__name__ == 'simple_add' + pyfunc = converted_ir_mod.pyfuncs["simple_add"] + assert hasattr(pyfunc, "__name__") + assert hasattr(pyfunc, "__doc__") + assert pyfunc.__name__ == "simple_add" def test_tensor_operations(self): """Test tensor manipulation operations.""" - converted_ir_mod = self.converter.convert([ - "test_reshape", "test_permute_dims", "test_concat", "test_split", - "test_stack", "test_take", "test_flip", "test_tile", "test_repeat", - "test_expand_dims", "test_squeeze", "test_sum_with_axis", "test_max_with_axis" - ]) - + converted_ir_mod = self.converter.convert( + [ + "test_reshape", + "test_permute_dims", + "test_concat", + "test_split", + "test_stack", + "test_take", + "test_flip", + "test_tile", + "test_repeat", + "test_expand_dims", + "test_squeeze", + "test_sum_with_axis", + "test_max_with_axis", + ] + ) + # Test reshape x1 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) result1 = converted_ir_mod.pyfuncs["test_reshape"](x1) expected1 = torch.reshape(x1, (6,)) assert torch.allclose(result1, expected1), "Reshape operation failed" - + # Test permute_dims x2 = torch.randn(2, 3, 4) result2 = converted_ir_mod.pyfuncs["test_permute_dims"](x2) expected2 = torch.permute(x2, (2, 0, 1)) assert torch.allclose(result2, expected2), "Permute_dims operation failed" - + # Test concat x3 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) y3 = torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32) result3 = converted_ir_mod.pyfuncs["test_concat"](x3, y3) expected3 = torch.cat([x3, y3], dim=0) assert torch.allclose(result3, expected3), "Concat operation failed" - + # Test split - x4 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], - [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32) + x4 = torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], + dtype=torch.float32, + ) result4 = converted_ir_mod.pyfuncs["test_split"](x4) expected4 = torch.split(x4, 2, dim=0) assert len(result4) == len(expected4), "Split operation failed - wrong number of tensors" for r, e in zip(result4, expected4): assert torch.allclose(r, e), "Split operation failed - tensor mismatch" - + # Test stack x5 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) y5 = torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32) result5 = converted_ir_mod.pyfuncs["test_stack"](x5, y5) expected5 = torch.stack([x5, y5], dim=1) assert torch.allclose(result5, expected5), "Stack operation failed" - + # Test take - x6 = torch.tensor([[1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 11.0, 12.0]], dtype=torch.float32) + x6 = torch.tensor( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], + dtype=torch.float32, + ) indices = torch.tensor([0, 2], dtype=torch.int64) result6 = converted_ir_mod.pyfuncs["test_take"](x6, indices) expected6 = x6[indices] assert torch.allclose(result6, expected6), "Take operation failed" - + # Test flip x7 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) result7 = converted_ir_mod.pyfuncs["test_flip"](x7) expected7 = torch.flip(x7, dims=[1]) assert torch.allclose(result7, expected7), "Flip operation failed" - + # Test tile x8 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) result8 = converted_ir_mod.pyfuncs["test_tile"](x8) expected8 = torch.tile(x8, (2, 2)) assert torch.allclose(result8, expected8), "Tile operation failed" - + # Test repeat x9 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) result9 = converted_ir_mod.pyfuncs["test_repeat"](x9) expected9 = torch.repeat_interleave(x9, repeats=2, dim=0) assert torch.allclose(result9, expected9), "Repeat operation failed" - + # Test expand_dims x10 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) result10 = converted_ir_mod.pyfuncs["test_expand_dims"](x10) expected10 = torch.unsqueeze(x10, dim=2) assert torch.allclose(result10, expected10), "Expand_dims operation failed" - + # Test squeeze x11 = torch.tensor([[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]], dtype=torch.float32) result11 = converted_ir_mod.pyfuncs["test_squeeze"](x11) expected11 = torch.squeeze(x11, dim=2) assert torch.allclose(result11, expected11), "Squeeze operation failed" - + # Test sum with axis x12 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) result12 = converted_ir_mod.pyfuncs["test_sum_with_axis"](x12) expected12 = torch.sum(x12, dim=0) assert torch.allclose(result12, expected12), "Sum with axis operation failed" - + # Test max with axis x13 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) result13 = converted_ir_mod.pyfuncs["test_max_with_axis"](x13) @@ -533,36 +569,52 @@ def test_sigmoid(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): # Comparison operations not covered in ComprehensiveTestModule @R.function - def test_less(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "bool"): + def test_less( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "bool"): return R.less(x, y) @R.function - def test_not_equal(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "bool"): + def test_not_equal( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "bool"): return R.not_equal(x, y) # Binary operations not covered in ComprehensiveTestModule @R.function - def test_multiply(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + def test_multiply( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): return R.multiply(x, y) @R.function - def test_divide(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + def test_divide( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): return R.divide(x, y) @R.function - def test_power(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + def test_power( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): return R.power(x, y) @R.function - def test_maximum(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + def test_maximum( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): return R.maximum(x, y) @R.function - def test_minimum(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + def test_minimum( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): return R.minimum(x, y) @R.function - def test_subtract(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + def test_subtract( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): return R.subtract(x, y) # Additional tensor operations with different parameters @@ -612,198 +664,199 @@ def setup_class(cls): def test_unary_operations(self): """Test unary operations.""" - converted_ir_mod = self.converter.convert([ - "test_abs", "test_neg", "test_exp", "test_log", "test_sqrt" - ]) - + converted_ir_mod = self.converter.convert( + ["test_abs", "test_neg", "test_exp", "test_log", "test_sqrt"] + ) + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32) - + # Test abs - result = converted_ir_mod.pyfuncs['test_abs'](x) + result = converted_ir_mod.pyfuncs["test_abs"](x) expected = torch.abs(x) assert torch.allclose(result, expected) - + # Test negative - result = converted_ir_mod.pyfuncs['test_neg'](x) + result = converted_ir_mod.pyfuncs["test_neg"](x) expected = torch.neg(x) assert torch.allclose(result, expected) - + # Test exp - result = converted_ir_mod.pyfuncs['test_exp'](x) + result = converted_ir_mod.pyfuncs["test_exp"](x) expected = torch.exp(x) assert torch.allclose(result, expected) - + # Test log (with positive values) x_pos = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) - result = converted_ir_mod.pyfuncs['test_log'](x_pos) + result = converted_ir_mod.pyfuncs["test_log"](x_pos) expected = torch.log(x_pos) assert torch.allclose(result, expected) - + # Test sqrt - result = converted_ir_mod.pyfuncs['test_sqrt'](x_pos) + result = converted_ir_mod.pyfuncs["test_sqrt"](x_pos) expected = torch.sqrt(x_pos) assert torch.allclose(result, expected) def test_trigonometric_operations(self): """Test trigonometric operations.""" - converted_ir_mod = self.converter.convert([ - "test_sin", "test_cos", "test_tanh", "test_sigmoid" - ]) - + converted_ir_mod = self.converter.convert( + ["test_sin", "test_cos", "test_tanh", "test_sigmoid"] + ) + x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0], dtype=torch.float32) - + # Test sin - result = converted_ir_mod.pyfuncs['test_sin'](x) + result = converted_ir_mod.pyfuncs["test_sin"](x) expected = torch.sin(x) assert torch.allclose(result, expected) - + # Test cos - result = converted_ir_mod.pyfuncs['test_cos'](x) + result = converted_ir_mod.pyfuncs["test_cos"](x) expected = torch.cos(x) assert torch.allclose(result, expected) - + # Test tanh - result = converted_ir_mod.pyfuncs['test_tanh'](x) + result = converted_ir_mod.pyfuncs["test_tanh"](x) expected = torch.tanh(x) assert torch.allclose(result, expected) - + # Test sigmoid - result = converted_ir_mod.pyfuncs['test_sigmoid'](x) + result = converted_ir_mod.pyfuncs["test_sigmoid"](x) expected = torch.sigmoid(x) assert torch.allclose(result, expected) def test_comparison_operations(self): """Test comparison operations not covered in ComprehensiveTestModule.""" - converted_ir_mod = self.converter.convert([ - "test_less", "test_not_equal" - ]) - + converted_ir_mod = self.converter.convert(["test_less", "test_not_equal"]) + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) y = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0], dtype=torch.float32) - + # Test less - result = converted_ir_mod.pyfuncs['test_less'](x, y) + result = converted_ir_mod.pyfuncs["test_less"](x, y) expected = torch.lt(x, y) assert torch.equal(result, expected) - + # Test not equal - result = converted_ir_mod.pyfuncs['test_not_equal'](x, y) + result = converted_ir_mod.pyfuncs["test_not_equal"](x, y) expected = torch.ne(x, y) assert torch.equal(result, expected) def test_binary_operations(self): """Test binary operations.""" - converted_ir_mod = self.converter.convert([ - "test_multiply", "test_divide", "test_power", "test_maximum", "test_minimum", "test_subtract" - ]) - + converted_ir_mod = self.converter.convert( + [ + "test_multiply", + "test_divide", + "test_power", + "test_maximum", + "test_minimum", + "test_subtract", + ] + ) + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) y = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0], dtype=torch.float32) - + # Test multiply - result = converted_ir_mod.pyfuncs['test_multiply'](x, y) + result = converted_ir_mod.pyfuncs["test_multiply"](x, y) expected = torch.mul(x, y) assert torch.allclose(result, expected) - + # Test divide - result = converted_ir_mod.pyfuncs['test_divide'](x, y) + result = converted_ir_mod.pyfuncs["test_divide"](x, y) expected = torch.div(x, y) assert torch.allclose(result, expected) - + # Test power - result = converted_ir_mod.pyfuncs['test_power'](x, y) + result = converted_ir_mod.pyfuncs["test_power"](x, y) expected = torch.pow(x, y) assert torch.allclose(result, expected) - + # Test maximum - result = converted_ir_mod.pyfuncs['test_maximum'](x, y) + result = converted_ir_mod.pyfuncs["test_maximum"](x, y) expected = torch.maximum(x, y) assert torch.allclose(result, expected) - + # Test minimum - result = converted_ir_mod.pyfuncs['test_minimum'](x, y) + result = converted_ir_mod.pyfuncs["test_minimum"](x, y) expected = torch.minimum(x, y) assert torch.allclose(result, expected) - + # Test subtract - result = converted_ir_mod.pyfuncs['test_subtract'](x, y) + result = converted_ir_mod.pyfuncs["test_subtract"](x, y) expected = torch.sub(x, y) assert torch.allclose(result, expected) def test_tensor_operations(self): """Test tensor operations not covered in ComprehensiveTestModule.""" - converted_ir_mod = self.converter.convert([ - "test_transpose_2d" - ]) - + converted_ir_mod = self.converter.convert(["test_transpose_2d"]) + x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32) - + # Test transpose - result = converted_ir_mod.pyfuncs['test_transpose_2d'](x) + result = converted_ir_mod.pyfuncs["test_transpose_2d"](x) expected = torch.transpose(x, 0, 1) assert torch.allclose(result, expected) assert result.shape == (4, 2) def test_reduction_operations(self): """Test reduction operations not covered in ComprehensiveTestModule.""" - converted_ir_mod = self.converter.convert([ - "test_mean_axis", "test_min_axis" - ]) - + converted_ir_mod = self.converter.convert(["test_mean_axis", "test_min_axis"]) + x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) - + # Test mean - result = converted_ir_mod.pyfuncs['test_mean_axis'](x) + result = converted_ir_mod.pyfuncs["test_mean_axis"](x) expected = torch.mean(x, dim=0) assert torch.allclose(result, expected) assert result.shape == (3,) - + # Test min - result = converted_ir_mod.pyfuncs['test_min_axis'](x) + result = converted_ir_mod.pyfuncs["test_min_axis"](x) expected = torch.min(x, dim=0)[0] assert torch.allclose(result, expected) assert result.shape == (3,) def test_neural_network_operations(self): """Test neural network operations not covered in ComprehensiveTestModule.""" - converted_ir_mod = self.converter.convert([ - "test_gelu_nn", "test_softmax_nn", "test_log_softmax_nn" - ]) - - x = torch.tensor([[-2.0, -1.0, 0.0, 1.0, 2.0], [0.5, 1.5, 2.5, 3.5, 4.5]], dtype=torch.float32) - + converted_ir_mod = self.converter.convert( + ["test_gelu_nn", "test_softmax_nn", "test_log_softmax_nn"] + ) + + x = torch.tensor( + [[-2.0, -1.0, 0.0, 1.0, 2.0], [0.5, 1.5, 2.5, 3.5, 4.5]], dtype=torch.float32 + ) + # Test gelu - result = converted_ir_mod.pyfuncs['test_gelu_nn'](x[0]) + result = converted_ir_mod.pyfuncs["test_gelu_nn"](x[0]) expected = F.gelu(x[0]) assert torch.allclose(result, expected) - + # Test softmax - result = converted_ir_mod.pyfuncs['test_softmax_nn'](x) + result = converted_ir_mod.pyfuncs["test_softmax_nn"](x) expected = F.softmax(x, dim=1) assert torch.allclose(result, expected) - + # Test log_softmax - result = converted_ir_mod.pyfuncs['test_log_softmax_nn'](x) + result = converted_ir_mod.pyfuncs["test_log_softmax_nn"](x) expected = F.log_softmax(x, dim=1) assert torch.allclose(result, expected) def test_advanced_tensor_operations(self): """Test advanced tensor operations with different parameters.""" - converted_ir_mod = self.converter.convert([ - "test_tile_dims", "test_repeat_axis" - ]) - + converted_ir_mod = self.converter.convert(["test_tile_dims", "test_repeat_axis"]) + x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32) - + # Test tile with different dimensions - result = converted_ir_mod.pyfuncs['test_tile_dims'](x) + result = converted_ir_mod.pyfuncs["test_tile_dims"](x) expected = torch.tile(x, (2, 3)) assert torch.allclose(result, expected) assert result.shape == (4, 12) - + # Test repeat with different parameters x_1d = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) - result = converted_ir_mod.pyfuncs['test_repeat_axis'](x_1d) + result = converted_ir_mod.pyfuncs["test_repeat_axis"](x_1d) expected = torch.repeat_interleave(x_1d, repeats=2, dim=0) assert torch.allclose(result, expected) assert result.shape == (6,) From 2996cbd7365583b90696200d2941c66be11aafaa Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Thu, 4 Sep 2025 23:20:23 +0800 Subject: [PATCH 4/4] finish4 --- python/tvm/relax/relax_to_pyfunc_converter.py | 128 ++++++++++-------- .../relax/test_relax_to_pyfunc_converter.py | 2 +- 2 files changed, 76 insertions(+), 54 deletions(-) diff --git a/python/tvm/relax/relax_to_pyfunc_converter.py b/python/tvm/relax/relax_to_pyfunc_converter.py index 4306a6929fc3..3de27d78c863 100644 --- a/python/tvm/relax/relax_to_pyfunc_converter.py +++ b/python/tvm/relax/relax_to_pyfunc_converter.py @@ -22,11 +22,12 @@ from typing import Any, Dict, List, Union +import torch +import torch.nn.functional as F + import tvm from tvm import relax from tvm.ir import IRModule, Op -import torch -import torch.nn.functional as F class RelaxToPyFuncConverter: @@ -44,7 +45,7 @@ def __init__(self, ir_module: IRModule): ir_module: The IRModule containing Relax functions to convert """ self.ir_module = ir_module - self.operator_map = self._get_relax_to_pytorch_operator_map() + self.operator_map = self._get_op_map() # Cache for RelaxExpressionConverter instances to avoid recreating them self._converter_cache = {} # Cache for operator mappings to avoid repeated lookups @@ -78,9 +79,9 @@ def convert(self, relax_function_names: Union[str, List[str]]) -> IRModule: # Get Relax function names from IRModule relax_func_names = [] - for gv, func in self.ir_module.functions_items(): + for global_var, func in self.ir_module.functions_items(): if isinstance(func, relax.Function): - relax_func_names.append(gv.name_hint) + relax_func_names.append(global_var.name_hint) # Convert each Relax function for func_name in relax_function_names: @@ -89,8 +90,8 @@ def convert(self, relax_function_names: Union[str, List[str]]) -> IRModule: # Get the Relax function relax_func = None - for gv, func in self.ir_module.functions_items(): - if gv.name_hint == func_name and isinstance(func, relax.Function): + for global_var, func in self.ir_module.functions_items(): + if global_var.name_hint == func_name and isinstance(func, relax.Function): relax_func = func break @@ -98,22 +99,20 @@ def convert(self, relax_function_names: Union[str, List[str]]) -> IRModule: raise ValueError(f"Could not find Relax function '{func_name}'") # Convert to Python function - py_func = self._convert_relax_function_to_python(relax_func, func_name) + py_func = self._convert_relax_func_to_python(relax_func, func_name) # Store in pyfuncs new_ir_mod.pyfuncs[func_name] = py_func return new_ir_mod - def _convert_relax_function_to_python( - self, relax_func: relax.Function, func_name: str - ) -> callable: + def _convert_relax_func_to_python(self, relax_func: relax.Function, func_name: str) -> callable: """Convert a single Relax function to a Python function with caching.""" # Get function parameters params = relax_func.params # Create the Python function - def converted_function(*args, **kwargs): + def converted_function(*args, **_kwargs): """Converted Python function from Relax function.""" # Handle arguments if len(args) != len(params): @@ -137,7 +136,7 @@ def converted_function(*args, **kwargs): return converted_function @staticmethod - def _get_relax_to_pytorch_operator_map() -> Dict[str, str]: + def _get_op_map() -> Dict[str, str]: """Get the mapping from Relax operators to PyTorch operators.""" return { # Binary operations @@ -480,47 +479,70 @@ def _convert_operator_call(self, call: relax.Call, args: List[Any]) -> Any: return self._convert_repeat(call, args) # Handle special cases for PyTorch operations elif pytorch_op.startswith("F."): - # Neural network function - func_name = pytorch_op[2:] # Remove "F." prefix - func = getattr(F, func_name) - - # Special handling for functions that need dim parameter - if func_name in ["softmax", "log_softmax"]: - # Extract axis from call.attrs and convert to dim - axis = None - if call.attrs and hasattr(call.attrs, "axis"): - axis = call.attrs.axis - if hasattr(axis, "value"): - axis = int(axis.value) - elif isinstance(axis, (int, float)): - axis = int(axis) - - if axis is not None: - return func(call_args[0], dim=axis) - else: - # Default to last dimension if no axis specified - return func(call_args[0], dim=-1) - else: - return func(*call_args) + return self._handle_functional_operation(pytorch_op, call, call_args) elif pytorch_op.startswith("torch."): # Regular PyTorch operation func_name = pytorch_op[6:] # Remove "torch." prefix func = getattr(torch, func_name) return func(*call_args) else: - # Direct function reference - return eval(pytorch_op)(*call_args) - except Exception as e: + # Direct function reference - use getattr for safer access + if pytorch_op.startswith("torch."): + module = torch + func_name = pytorch_op[6:] # Remove "torch." prefix + elif pytorch_op.startswith("F."): + module = F + func_name = pytorch_op[2:] # Remove "F." prefix + else: + return ( + f"" + ) + + func = getattr(module, func_name, None) + if func is None: + return ( + f"" + ) + return func(*call_args) + except (AttributeError, TypeError, ValueError) as error: # This allows the test framework to catch and handle the errors appropriately if pytorch_op.startswith("torch.") or pytorch_op.startswith("F."): - raise e - else: - # Fallback to string representation for non-PyTorch operations - return f"" + raise error + # Fallback to string representation for non-PyTorch operations + return f"" else: # Unknown operator return f"" + def _handle_functional_operation( + self, pytorch_op: str, call: relax.Call, call_args: List[Any] + ) -> Any: + """Handle PyTorch functional operations with special parameter handling.""" + # Neural network function + func_name = pytorch_op[2:] # Remove "F." prefix + func = getattr(F, func_name) + + # Special handling for functions that need dim parameter + if func_name in ["softmax", "log_softmax"]: + # Extract axis from call.attrs and convert to dim + axis = None + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + return func(call_args[0], dim=axis) + else: + # Default to last dimension if no axis specified + return func(call_args[0], dim=-1) + else: + return func(*call_args) + def _convert_extern_func_call(self, call: relax.Call, args: List[Any]) -> Any: """Convert an external function call.""" func_name = call.op.global_symbol @@ -558,15 +580,15 @@ def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any: tir_function = None if self.ir_module: # Look for the TIR function in the current IRModule - for gv, func in self.ir_module.functions.items(): - if gv.name_hint == func_name and hasattr(func, "body"): + for global_var, func in self.ir_module.functions.items(): + if global_var.name_hint == func_name and hasattr(func, "body"): try: # Compile the TIR function target = tvm.target.Target("llvm") with tvm.target.Target(target): tir_function = tvm.compile(func, target=target) break - except Exception as compile_e: + except (RuntimeError, ValueError, TypeError) as compile_e: print( f"Warning: Failed to compile TIR function {func_name}: {compile_e}" ) @@ -615,15 +637,15 @@ def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any: # Convert result back to PyTorch tensor via DLPack return torch.from_dlpack(output_tensor.to_dlpack()) - except Exception as e: - return f"" + except (RuntimeError, ValueError, TypeError) as error: + return f"" def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any: """Convert call_dps_packed to Python equivalent with DLPack conversion.""" # Extract packed function name and arguments packed_func = call.args[0] packed_args = call.args[1] if len(call.args) > 1 else [] - out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None + _out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None # Get function name if isinstance(packed_func, relax.GlobalVar): @@ -661,8 +683,8 @@ def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any: else: return result - except Exception as e: - return f"" + except (RuntimeError, ValueError, TypeError) as error: + return f"" def _convert_constant(self, const: relax.Constant) -> Any: """Convert a Relax constant to Python equivalent.""" @@ -718,7 +740,7 @@ def _convert_if(self, if_expr: relax.If, args: List[Any]) -> Any: def _convert_expand_dims(self, call: relax.Call, args: List[Any]) -> Any: """Convert expand_dims to torch.unsqueeze with proper axis handling.""" if len(call.args) < 1: - return f"" + return "" # Convert the tensor argument tensor_arg = self.convert_expr(call.args[0], args) @@ -740,7 +762,7 @@ def _convert_expand_dims(self, call: relax.Call, args: List[Any]) -> Any: axis = int(axis) if axis is None: - return f"" + return "" # Use torch.unsqueeze with the correct axis return torch.unsqueeze(tensor_arg, dim=axis) @@ -799,7 +821,7 @@ def _convert_reduction_op(self, call: relax.Call, args: List[Any], op_name: str) def _convert_squeeze(self, call: relax.Call, args: List[Any]) -> Any: """Convert squeeze to torch.squeeze with proper axis handling.""" if len(call.args) < 1: - return f"" + return "" # Convert the tensor argument tensor_arg = self.convert_expr(call.args[0], args) diff --git a/tests/python/relax/test_relax_to_pyfunc_converter.py b/tests/python/relax/test_relax_to_pyfunc_converter.py index 8133be38a0b3..6dce3093156f 100644 --- a/tests/python/relax/test_relax_to_pyfunc_converter.py +++ b/tests/python/relax/test_relax_to_pyfunc_converter.py @@ -353,7 +353,7 @@ def test_comparison_operations(self): def test_operator_mapping_completeness(self): """Test that operator mapping is comprehensive.""" - operator_map = RelaxToPyFuncConverter._get_relax_to_pytorch_operator_map() + operator_map = RelaxToPyFuncConverter._get_op_map() # Check that we have a good number of operators assert len(operator_map) > 100, f"Expected >100 operators, got {len(operator_map)}"