36 changes: 29 additions & 7 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -33,11 +33,12 @@ class TorchFXImporter(BaseFXGraphImporter):
import torch # type: ignore
from torch import fx

def __init__(self) -> None:
def __init__(self, default_image_layout: str = "NCHW") -> None:
import torch # type: ignore

super().__init__()
self.named_modules: Dict[str, torch.Module] = None
self.default_image_layout = default_image_layout

########## Utilities ##########

@@ -480,7 +481,6 @@ def _interpolate(self, node: fx.Node) -> relax.Var:
# torch.nn.functional.interpolate(
# input, size=None, scale_factor=None, mode='nearest', align_corners=None,
# recompute_scale_factor=None, antialias=False)
# (TODO) this is a temporary implementation for interpolate that only considers NCHW layout
data = self.env[node.args[0]]
size = (
node.args[1]
@@ -523,13 +523,26 @@ def _interpolate(self, node: fx.Node) -> relax.Var:
if size is None:
shape = self.shape_of(data)
assert isinstance(shape, relax.ShapeExpr)
# Determine spatial dimension indices based on layout
# NCHW: spatial dims are [2, 3, ...] (skip batch and channel)
# NHWC: spatial dims are [1, 2, ...] (skip batch, before channel)
if self.default_image_layout == "NHWC":
spatial_start = 1
spatial_end = len(shape) - 1
else: # NCHW or other layouts
spatial_start = 2
spatial_end = len(shape)
Comment on lines +532 to +534
Severity: medium

The else branch currently covers "NCHW or other layouts". NCHW is the default and likely the primary alternative, but if only NCHW is explicitly supported besides NHWC, it would be clearer to write the condition as elif self.default_image_layout == "NCHW":. If other layouts are genuinely meant to fall into this branch, a comment explaining that assumption would help future maintainability.

Suggested change
else: # NCHW or other layouts
spatial_start = 2
spatial_end = len(shape)
else: # NCHW
spatial_start = 2
spatial_end = len(shape)
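
For reference, a minimal standalone sketch of the layout-dependent size computation this hunk performs. The helper name compute_resize_size and the plain-tuple shape are illustrative assumptions, not part of the PR; _interpolate does this inline against a relax.ShapeExpr.

# Illustrative sketch (not part of the PR): the helper name and plain-tuple shape
# are assumptions; the importer works on a relax.ShapeExpr instead.
from typing import Sequence, Tuple, Union

def compute_resize_size(
    shape: Sequence[int],
    scale_factor: Union[float, Tuple[float, ...]],
    layout: str = "NCHW",
) -> Tuple[int, ...]:
    # NCHW: spatial dims follow batch and channel, i.e. indices [2, len(shape))
    # NHWC: spatial dims sit between batch and channel, i.e. indices [1, len(shape) - 1)
    if layout == "NHWC":
        spatial_start, spatial_end = 1, len(shape) - 1
    else:  # NCHW
        spatial_start, spatial_end = 2, len(shape)
    if isinstance(scale_factor, tuple):
        assert len(scale_factor) == spatial_end - spatial_start
        return tuple(
            int(shape[i] * scale_factor[i - spatial_start])
            for i in range(spatial_start, spatial_end)
        )
    return tuple(int(shape[i] * scale_factor) for i in range(spatial_start, spatial_end))

# e.g. compute_resize_size((1, 10, 10, 3), 2.0, layout="NHWC") -> (20, 20)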


if isinstance(scale_factor, tuple):
assert len(scale_factor) == len(shape) - 2
assert len(scale_factor) == spatial_end - spatial_start
size = tuple(
int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape))
int(shape[i].value * scale_factor[i - spatial_start])
for i in range(spatial_start, spatial_end)
)
else:
size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape)))
size = tuple(
int(shape[i].value * scale_factor) for i in range(spatial_start, spatial_end)
)

if method.startswith("nearest"):
method = "nearest_neighbor"
@@ -545,7 +558,11 @@ def _interpolate(self, node: fx.Node) -> relax.Var:

return self.block_builder.emit(
relax.op.image.resize2d(
data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans
data,
size,
layout=self.default_image_layout,
method=method,
coordinate_transformation_mode=coord_trans,
)
)

@@ -1150,6 +1167,7 @@ def from_fx(
unwrap_unit_return_tuple: bool = False,
no_bind_return_tuple: bool = False,
custom_convert_map: dict = None,
default_image_layout: str = "NCHW",
) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program

@@ -1175,6 +1193,10 @@
custom_convert_map : Dictionary of str to Relax op
A custom op conversion map in the same format as TorchFXImporter.convert_map

default_image_layout : str
The default layout for image operations (e.g., "NCHW" or "NHWC").
Default is "NCHW" which is the standard PyTorch layout.

Returns
-------
output : tvm.IRModule
@@ -1242,7 +1264,7 @@ def forward(self, input):
to print out the tabular representation of the PyTorch module, and then
check the placeholder rows in the beginning of the tabular.
"""
return TorchFXImporter().from_fx(
return TorchFXImporter(default_image_layout=default_image_layout).from_fx(
model,
input_info,
keep_params_as_input,
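
As a usage sketch of the new keyword: the ToyUpsample module and the 1x10x10x3 input shape below are placeholders rather than part of this PR, and the call mirrors the tests added in the next file.

# Hedged usage sketch: ToyUpsample and the input shape are placeholders;
# from_fx forwards default_image_layout to TorchFXImporter.
import torch
from torch import fx, nn
from tvm.relax.frontend.torch import from_fx

class ToyUpsample(nn.Module):
    def forward(self, x):
        return torch.nn.functional.interpolate(x, size=(5, 5))

graph_model = fx.symbolic_trace(ToyUpsample())
with torch.no_grad():
    # Channels-last input: interpolate now lowers to R.image.resize2d with layout="NHWC"
    mod = from_fx(graph_model, [([1, 10, 10, 3], "float32")], default_image_layout="NHWC")
mod.show()
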
115 changes: 115 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
@@ -3670,6 +3670,121 @@ def main(
verify_model(Interpolate4(), input_info, {}, expected4)


def test_interpolate_nhwc_layout():
# First verify backward compatibility - default should still be NCHW
input_info_nchw = [([1, 3, 10, 10], "float32")]

class InterpolateDefault(Module):
def forward(self, input):
return torch.nn.functional.interpolate(input, (5, 5))

@tvm.script.ir_module
class expected_default_nchw:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 5, 5), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 5, 5), dtype="float32") = R.image.resize2d(
input_1,
(5, 5),
roi=[0.000000, 0.000000, 0.000000, 0.000000],
layout="NCHW",
method="nearest_neighbor",
coordinate_transformation_mode="asymmetric",
rounding_method="round",
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
)
gv: R.Tensor((1, 3, 5, 5), dtype="float32") = lv
R.output(gv)
return gv

# Verify default behavior (no default_image_layout parameter) uses NCHW
graph_model_default = fx.symbolic_trace(InterpolateDefault())
with torch.no_grad():
mod_default = from_fx(graph_model_default, input_info_nchw)
tvm.ir.assert_structural_equal(mod_default, expected_default_nchw)

# Now test NHWC layout
input_info = [([1, 10, 10, 3], "float32")]

class InterpolateNHWC(Module):
def forward(self, input):
return torch.nn.functional.interpolate(input, (5, 5))

@tvm.script.ir_module
class expected_nhwc:
@R.function
def main(
input_1: R.Tensor((1, 10, 10, 3), dtype="float32")
) -> R.Tensor((1, 5, 5, 3), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 5, 5, 3), dtype="float32") = R.image.resize2d(
input_1,
(5, 5),
roi=[0.000000, 0.000000, 0.000000, 0.000000],
layout="NHWC",
method="nearest_neighbor",
coordinate_transformation_mode="asymmetric",
rounding_method="round",
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
)
gv: R.Tensor((1, 5, 5, 3), dtype="float32") = lv
R.output(gv)
return gv

# Test with NHWC layout
graph_model = fx.symbolic_trace(InterpolateNHWC())
with torch.no_grad():
mod = from_fx(graph_model, input_info, default_image_layout="NHWC")
tvm.ir.assert_structural_equal(mod, expected_nhwc)

# Test with bilinear interpolation and NHWC layout
class InterpolateNHWC2(Module):
def forward(self, input):
return torch.nn.functional.interpolate(
input, size=None, scale_factor=2.0, mode="bilinear", align_corners=False
)

@tvm.script.ir_module
class expected_nhwc2:
@R.function
def main(
input_1: R.Tensor((1, 10, 10, 3), dtype="float32")
) -> R.Tensor((1, 20, 20, 3), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 20, 20, 3), dtype="float32") = R.image.resize2d(
input_1,
(20, 20),
roi=[0.000000, 0.000000, 0.000000, 0.000000],
layout="NHWC",
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
)
gv: R.Tensor((1, 20, 20, 3), dtype="float32") = lv
R.output(gv)
return gv

graph_model2 = fx.symbolic_trace(InterpolateNHWC2())
with torch.no_grad():
mod2 = from_fx(graph_model2, input_info, default_image_layout="NHWC")
tvm.ir.assert_structural_equal(mod2, expected_nhwc2)


def test_addmm():
input_info = [
([10, 10], "float32"),