From d0d7036347e00aa3e6100a7e2a4865cd9857c413 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Fri, 5 Dec 2025 21:56:31 +0800 Subject: [PATCH] Add NHWC layout support --- .../tvm/relax/frontend/torch/fx_translator.py | 36 ++++-- tests/python/relax/test_frontend_from_fx.py | 115 ++++++++++++++++++ 2 files changed, 144 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 9c2d53a68581..8b1f5de36b50 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -33,11 +33,12 @@ class TorchFXImporter(BaseFXGraphImporter): import torch # type: ignore from torch import fx - def __init__(self) -> None: + def __init__(self, default_image_layout: str = "NCHW") -> None: import torch # type: ignore super().__init__() self.named_modules: Dict[str, torch.Module] = None + self.default_image_layout = default_image_layout ########## Utilities ########## @@ -480,7 +481,6 @@ def _interpolate(self, node: fx.Node) -> relax.Var: # torch.nn.functional.interpolate( # input, size=None, scale_factor=None, mode='nearest', align_corners=None, # recompute_scale_factor=None, antialias=False) - # (TODO) this is a temporary implementation for interpolate that only considers NCHW layout data = self.env[node.args[0]] size = ( node.args[1] @@ -523,13 +523,26 @@ def _interpolate(self, node: fx.Node) -> relax.Var: if size is None: shape = self.shape_of(data) assert isinstance(shape, relax.ShapeExpr) + # Determine spatial dimension indices based on layout + # NCHW: spatial dims are [2, 3, ...] (skip batch and channel) + # NHWC: spatial dims are [1, 2, ...] (skip batch, before channel) + if self.default_image_layout == "NHWC": + spatial_start = 1 + spatial_end = len(shape) - 1 + else: # NCHW or other layouts + spatial_start = 2 + spatial_end = len(shape) + if isinstance(scale_factor, tuple): - assert len(scale_factor) == len(shape) - 2 + assert len(scale_factor) == spatial_end - spatial_start size = tuple( - int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape)) + int(shape[i].value * scale_factor[i - spatial_start]) + for i in range(spatial_start, spatial_end) ) else: - size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) + size = tuple( + int(shape[i].value * scale_factor) for i in range(spatial_start, spatial_end) + ) if method.startswith("nearest"): method = "nearest_neighbor" @@ -545,7 +558,11 @@ def _interpolate(self, node: fx.Node) -> relax.Var: return self.block_builder.emit( relax.op.image.resize2d( - data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + data, + size, + layout=self.default_image_layout, + method=method, + coordinate_transformation_mode=coord_trans, ) ) @@ -1150,6 +1167,7 @@ def from_fx( unwrap_unit_return_tuple: bool = False, no_bind_return_tuple: bool = False, custom_convert_map: dict = None, + default_image_layout: str = "NCHW", ) -> tvm.IRModule: """Convert a PyTorch FX GraphModule to a Relax program @@ -1175,6 +1193,10 @@ def from_fx( custom_convert_map : Dictionary of str to Relax op A custom op conversion map in the same format as TorchFXImporter.convert_map + default_image_layout : str + The default layout for image operations (e.g., "NCHW" or "NHWC"). + Default is "NCHW" which is the standard PyTorch layout. + Returns ------- output : tvm.IRModule @@ -1242,7 +1264,7 @@ def forward(self, input): to print out the tabular representation of the PyTorch module, and then check the placeholder rows in the beginning of the tabular. """ - return TorchFXImporter().from_fx( + return TorchFXImporter(default_image_layout=default_image_layout).from_fx( model, input_info, keep_params_as_input, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index de30af01ee01..b7aeea6687e8 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3670,6 +3670,121 @@ def main( verify_model(Interpolate4(), input_info, {}, expected4) +def test_interpolate_nhwc_layout(): + # First verify backward compatibility - default should still be NCHW + input_info_nchw = [([1, 3, 10, 10], "float32")] + + class InterpolateDefault(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, (5, 5)) + + @tvm.script.ir_module + class expected_default_nchw: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 5, 5), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 5), dtype="float32") = R.image.resize2d( + input_1, + (5, 5), + roi=[0.000000, 0.000000, 0.000000, 0.000000], + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="asymmetric", + rounding_method="round", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 3, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + # Verify default behavior (no default_image_layout parameter) uses NCHW + graph_model_default = fx.symbolic_trace(InterpolateDefault()) + with torch.no_grad(): + mod_default = from_fx(graph_model_default, input_info_nchw) + tvm.ir.assert_structural_equal(mod_default, expected_default_nchw) + + # Now test NHWC layout + input_info = [([1, 10, 10, 3], "float32")] + + class InterpolateNHWC(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, (5, 5)) + + @tvm.script.ir_module + class expected_nhwc: + @R.function + def main( + input_1: R.Tensor((1, 10, 10, 3), dtype="float32") + ) -> R.Tensor((1, 5, 5, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 5, 5, 3), dtype="float32") = R.image.resize2d( + input_1, + (5, 5), + roi=[0.000000, 0.000000, 0.000000, 0.000000], + layout="NHWC", + method="nearest_neighbor", + coordinate_transformation_mode="asymmetric", + rounding_method="round", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 5, 5, 3), dtype="float32") = lv + R.output(gv) + return gv + + # Test with NHWC layout + graph_model = fx.symbolic_trace(InterpolateNHWC()) + with torch.no_grad(): + mod = from_fx(graph_model, input_info, default_image_layout="NHWC") + tvm.ir.assert_structural_equal(mod, expected_nhwc) + + # Test with bilinear interpolation and NHWC layout + class InterpolateNHWC2(Module): + def forward(self, input): + return torch.nn.functional.interpolate( + input, size=None, scale_factor=2.0, mode="bilinear", align_corners=False + ) + + @tvm.script.ir_module + class expected_nhwc2: + @R.function + def main( + input_1: R.Tensor((1, 10, 10, 3), dtype="float32") + ) -> R.Tensor((1, 20, 20, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 20, 20, 3), dtype="float32") = R.image.resize2d( + input_1, + (20, 20), + roi=[0.000000, 0.000000, 0.000000, 0.000000], + layout="NHWC", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="round", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 20, 20, 3), dtype="float32") = lv + R.output(gv) + return gv + + graph_model2 = fx.symbolic_trace(InterpolateNHWC2()) + with torch.no_grad(): + mod2 = from_fx(graph_model2, input_info, default_image_layout="NHWC") + tvm.ir.assert_structural_equal(mod2, expected_nhwc2) + + def test_addmm(): input_info = [ ([10, 10], "float32"),