From b922a52bd016ff16951edaee4a091433c8983734 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Sat, 6 Dec 2025 12:07:55 +0800 Subject: [PATCH] Enhance scale_factor handling in interpolation --- .../torch/exported_program_translator.py | 18 +++---- .../test_frontend_from_exported_program.py | 51 +++++++++++++++++++ 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 2ec61796c31a..641e16f599df 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -337,11 +337,11 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: ) else: - # TODO figure out why pytorch export passes a list such as - # [scale_factor,scale_factor] instead of just an int for - # scale_factor. Using first element for now + # PyTorch export passes scale_factor as either a scalar or a list/tuple + # (e.g., [2.0, 3.0] for different H and W scaling). + # Pass it as-is to _upsample_impl which handles both cases correctly. scale_factor = ( - node.args[2][0] if len(node.args) > 2 else node.kwargs.get("scale_factor", 1) + node.args[2] if len(node.args) > 2 else node.kwargs.get("scale_factor", 1) ) align_corners = ( node.args[3] if len(node.args) > 3 else node.kwargs.get("align_corners", None) @@ -364,11 +364,11 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var: if size is not None: scale_factor = None else: - scale_arg = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1) - if isinstance(scale_arg, (list, tuple)): - scale_factor = scale_arg[0] - else: - scale_factor = scale_arg + # PyTorch export passes scale_factor as either a scalar or a list/tuple. + # Pass it as-is to _upsample_impl which handles both cases correctly. + scale_factor = ( + node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1) + ) return self._upsample_impl( x, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 010bd026a8ba..68567e1fc859 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -8542,5 +8542,56 @@ def main( verify_model(GridSample(), example_args, {}, expected) +def test_upsample_nearest2d(): + class UpsampleNearest2dScale(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest") + + class UpsampleNearest2dSize(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, size=(20, 20), mode="nearest") + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + @tvm.script.ir_module + class expected_scale: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 20, 20), dtype="float32") = R.image.resize2d( + input_1, + size=(20, 20), + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="half_pixel", + ) + gv: R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_size: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 20, 20), dtype="float32") = R.image.resize2d( + input_1, + size=(20, 20), + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="half_pixel", + ) + gv: R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(UpsampleNearest2dScale(), example_args, {}, expected_scale) + verify_model(UpsampleNearest2dSize(), example_args, {}, expected_size) + + if __name__ == "__main__": tvm.testing.main()