From 94bb8677cefd64f39390ca58b947e4183c73f4ba Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Mon, 5 May 2025 11:50:44 +0000 Subject: [PATCH 1/6] add max_pool 1d and 3d op support and refactor 2d op --- .../torch/base_fx_graph_translator.py | 94 +++- .../torch/exported_program_translator.py | 2 + .../test_frontend_from_exported_program.py | 199 ++++++++ tests/python/relax/test_frontend_from_fx.py | 194 ++++++++ tests/python/relax/test_op_nn_pooling.py | 441 ++++++++++++++++++ 5 files changed, 929 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 48869767ad66..2808805e1fae 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -862,6 +862,48 @@ def _linear(self, node: fx.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + def _max_pool1d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int]] = 1, + stride: Optional[Union[int, Tuple[int]]] = None, + padding: Optional[int] = 0, + dilation: Optional[int] = 1, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + x_ndim = x.struct_info.ndim + if x_ndim == 2: + x = relax.op.expand_dims(x, axis=0) + + stride = kernel_size if stride is None else stride + + result = self.block_builder.emit( + relax.op.nn.max_pool1d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + layout="NCW", + ) + ) + + if x_ndim == 2: + result = relax.op.squeeze(result, axis=[0]) + return result + + def _max_pool1d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + return self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _max_pool2d_impl( self, x: relax.Expr, @@ -871,8 +913,13 @@ def _max_pool2d_impl( dilation: Optional[int] = 1, ceil_mode: Optional[bool] = False, ) -> relax.Var: + x_ndim = x.struct_info.ndim + if x_ndim == 3: + x = relax.op.expand_dims(x, axis=0) + stride = kernel_size if stride is None else stride - return self.block_builder.emit( + + result = self.block_builder.emit( relax.op.nn.max_pool2d( x, pool_size=kernel_size, @@ -884,6 +931,10 @@ def _max_pool2d_impl( ) ) + if x_ndim == 3: + result = relax.op.squeeze(result, axis=[0]) + return result + def _max_pool2d(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] @@ -895,6 +946,47 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _max_pool3d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int, int, int]] = (1, 1, 1), + stride: Optional[Union[int, Tuple[int, int, int]]] = None, + padding: Optional[int] = 0, + dilation: Optional[int] = 1, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + x_ndim = x.struct_info.ndim + if x_ndim == 4: + x = relax.op.expand_dims(x, axis=0) + + stride = kernel_size if stride is None else stride + + result = self.block_builder.emit( + relax.op.nn.max_pool3d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + layout="NCDHW", + ) + ) + + if x_ndim == 4: + result = relax.op.squeeze(result, axis=[0]) + return result + + def _max_pool3d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _pad(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] pad = node.args[1] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index df532fd1ea04..1f1cc81a87b9 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -406,7 +406,9 @@ def create_convert_map( "group_norm.default": self._group_norm, "layer_norm.default": self._layer_norm, "linear.default": self._linear, + "max_pool1d.default": self._max_pool1d, "max_pool2d.default": self._max_pool2d, + "max_pool3d.default": self._max_pool3d, "scaled_dot_product_attention.default": self._scaled_dot_product_attention, "unbind.int": self._unbind, "upsample_bilinear2d.vec": self._upsample_bilinear2d, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index f0bb33964ef2..c69505e66645 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2299,6 +2299,101 @@ def main( verify_model(model, example_args, binding, expected2) +def test_maxpool1d(): + class MaxPool1d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool1d(kernel_size=2) + + def forward(self, input): + return self.pool(input) + + class MaxPool1d_functional(Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.max_pool1d(input, kernel_size=2) + + class MaxPool1d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 8), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): + with R.dataflow(): + lv = R.nn.max_pool1d( + input_1, + pool_size=[2], + strides=[2], + dilation=[1], + padding=[0, 0], + layout="NCW", + out_layout="NCW", + ) + gv = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 8), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): + with R.dataflow(): + lv = R.nn.max_pool1d( + input_1, + pool_size=[2], + strides=[2], + dilation=[1], + padding=[0, 0], + layout="NCW", + out_layout="NCW", + ) + gv = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): + with R.dataflow(): + lv = R.nn.max_pool1d( + input_1, + pool_size=[3], + strides=[2], + dilation=[1], + padding=[0, 0], + layout="NCW", + out_layout="NCW", + ) + gv = (lv,) + R.output(gv) + return gv + + # Example inputs + example_args1 = (torch.randn(1, 3, 8, dtype=torch.float32),) + example_args2 = (torch.randn(1, 3, 8, dtype=torch.float32),) + example_args3 = (torch.randn(1, 3, 10, dtype=torch.float32),) + + # Verify the models + verify_model(MaxPool1d(), example_args1, {}, expected1) + verify_model(MaxPool1d_functional(), example_args2, {}, expected2) + verify_model(MaxPool1d2(), example_args3, {}, expected3) + + def test_maxpool2d(): class MaxPool2d(Module): def __init__(self): @@ -2401,6 +2496,110 @@ def main( verify_model(MaxPool2d3(), example_args, {}, expected3) +def test_maxpool3d(): + class MaxPool3d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool3d(kernel_size=[1, 1, 1]) + + def forward(self, input): + return self.pool(input) + + class MaxPool3d_functional(Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.max_pool3d(input, kernel_size=[1, 1, 1]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 4, 4, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")): + with R.dataflow(): + lv = R.nn.max_pool3d( + input_1, + pool_size=[1, 1, 1], + strides=[1, 1, 1], + dilation=[1, 1, 1], + padding=[0, 0, 0, 0, 0, 0], + layout="NCDHW", + out_layout="NCDHW", + ) + gv = (lv,) + R.output(gv) + return gv + + class MaxPool3d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[2, 2, 2]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")): + with R.dataflow(): + lv = R.nn.max_pool3d( + input_1, + pool_size=[2, 2, 2], + strides=[2, 2, 2], + dilation=[2, 2, 2], + padding=[0, 0, 0, 0, 0, 0], + layout="NCDHW", + out_layout="NCDHW", + ) + gv = (lv,) + R.output(gv) + return gv + + class MaxPool3d3(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], padding=1, stride=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")): + with R.dataflow(): + lv = R.nn.max_pool3d( + input_1, + pool_size=[3, 3, 3], + strides=[2, 2, 2], + dilation=[1, 1, 1], + padding=[1, 1, 1, 1, 1, 1], + layout="NCDHW", + out_layout="NCDHW", + ) + gv = (lv,) + R.output(gv) + return gv + + # Example input tensors + example_args1 = (torch.randn(1, 3, 4, 4, 4, dtype=torch.float32),) + example_args2 = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),) + example_args3 = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),) + + # Verify the models with expected IR modules + verify_model(MaxPool3d(), example_args1, {}, expected1) + verify_model(MaxPool3d_functional(), example_args1, {}, expected1) + verify_model(MaxPool3d2(), example_args2, {}, expected2) + verify_model(MaxPool3d3(), example_args3, {}, expected3) + + def test_scaled_dot_product_attention(): class Attention1(Module): def forward(self, q, k, v): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 490a2309aa37..5b069e417808 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -984,6 +984,100 @@ def main( verify_model(Prelu2(), input_info, {}, expected) +def test_maxpool1d(): + input_info = [([1, 3, 10], "float32")] + + class MaxPool1d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool1d(kernel_size=2) + + def forward(self, input): + return self.pool(input) + + class MaxPool1d_functional(Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.max_pool1d(input, kernel_size=2) + + @tvm.script.ir_module + class expected1: + @R.function + def main(input_1: R.Tensor((1, 3, 10), dtype="float32")) -> R.Tensor((1, 3, 5), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.max_pool1d( + input_1, + pool_size=[2], + strides=[2], + dilation=[1], + padding=[0, 0], + layout="NCW", + out_layout="NCW" + ) + gv: R.Tensor((1, 3, 5), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool1d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=1, padding=1) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main(input_1: R.Tensor((1, 3, 10), dtype="float32")) -> R.Tensor((1, 3, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10), dtype="float32") = R.nn.max_pool1d( + input_1, + pool_size=[3], + strides=[1], + dilation=[1], + padding=[1, 1], + layout="NCW", + out_layout="NCW" + ) + gv: R.Tensor((1, 3, 10), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool1d3(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=2, dilation=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main(input_1: R.Tensor((1, 3, 10), dtype="float32")) -> R.Tensor((1, 3, 3), dtype="float32"): # Corrected here + with R.dataflow(): + lv: R.Tensor((1, 3, 3), dtype="float32") = R.nn.max_pool1d( + input_1, + pool_size=[3], + strides=[2], + dilation=[2], + padding=[0, 0], + layout="NCW", + out_layout="NCW" + ) + gv: R.Tensor((1, 3, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(MaxPool1d(), input_info, {}, expected1) + verify_model(MaxPool1d_functional(), input_info, {}, expected1) + verify_model(MaxPool1d2(), input_info, {}, expected2) + verify_model(MaxPool1d3(), input_info, {}, expected3) + + def test_maxpool2d(): input_info = [([1, 3, 10, 10], "float32")] @@ -1087,6 +1181,106 @@ def main( verify_model(MaxPool2d3(), input_info, {}, expected3) +def test_maxpool3d(): + input_info = [([1, 3, 10, 10, 10], "float32")] + + class MaxPool3d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool3d(kernel_size=[1, 1, 1]) + + def forward(self, input): + return self.pool(input) + + class MaxPool3d_functional(Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.max_pool3d(input, kernel_size=[1, 1, 1]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10, 10), dtype="float32") = R.nn.max_pool3d( + input_1, + pool_size=[1, 1, 1], + strides=[1, 1, 1], + dilation=[1, 1, 1], + padding=[0, 0, 0, 0, 0, 0], + layout="NCDHW", + out_layout="NCDHW", + ) + gv: R.Tensor((1, 3, 10, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool3d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[1, 2, 2]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 5, 4, 4), dtype="float32"): # Fixed here + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 4, 4), dtype="float32") = R.nn.max_pool3d( + input_1, + pool_size=[2, 2, 2], + strides=[2, 2, 2], + dilation=[1, 2, 2], + padding=[0, 0, 0, 0, 0, 0], + layout="NCDHW", + out_layout="NCDHW", + ) + gv: R.Tensor((1, 3, 5, 4, 4), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool3d3(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], padding=1, stride=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 5, 5, 5), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.nn.max_pool3d( + input_1, + pool_size=[3, 3, 3], + strides=[2, 2, 2], + dilation=[1, 1, 1], + padding=[1, 1, 1, 1, 1, 1], + layout="NCDHW", + out_layout="NCDHW", + ) + gv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(MaxPool3d(), input_info, {}, expected1) + verify_model(MaxPool3d_functional(), input_info, {}, expected1) + verify_model(MaxPool3d2(), input_info, {}, expected2) + verify_model(MaxPool3d3(), input_info, {}, expected3) + + def test_avgpool2d(): input_info = [([1, 3, 10, 10], "float32")] diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py index 2533a2fcadcb..6423e2adae1f 100644 --- a/tests/python/relax/test_op_nn_pooling.py +++ b/tests/python/relax/test_op_nn_pooling.py @@ -25,7 +25,11 @@ def test_op_correctness(): x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x1", R.Tensor((2, 3, 64), "float32")) + x2 = relax.Var("x2", R.Tensor((2, 3, 8, 28, 28), "float32")) + assert relax.op.nn.max_pool1d(x1).op == Op.get("relax.nn.max_pool1d") assert relax.op.nn.max_pool2d(x).op == Op.get("relax.nn.max_pool2d") + assert relax.op.nn.max_pool3d(x2).op == Op.get("relax.nn.max_pool3d") assert relax.op.nn.avg_pool2d(x).op == Op.get("relax.nn.avg_pool2d") assert relax.op.nn.adaptive_avg_pool2d(x).op == Op.get("relax.nn.adaptive_avg_pool2d") @@ -35,6 +39,206 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) +def test_max_pool1d_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor(ndim=3)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor()) + x5 = relax.Var("x", R.Tensor((2, 3, 32), "float32", vdev0)) + + _check_inference( + bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float32") + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x5), relax.TensorStructInfo((2, 3, 32), "float32", vdev0) + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x0, pool_size=3), relax.TensorStructInfo((2, 3, 30), "float32") + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x0, strides=2), relax.TensorStructInfo((2, 3, 16), "float32") + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x0, padding=1), relax.TensorStructInfo((2, 3, 32), "float32") + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x0, dilation=2), relax.TensorStructInfo((2, 3, 32), "float32") + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x0, layout="NCW", out_layout="NWC"), + relax.TensorStructInfo((2, 32, 3), "float32") + ) + _check_inference(bb, relax.op.nn.max_pool1d(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.max_pool1d(x2), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.nn.max_pool1d(x3), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.max_pool1d(x4), relax.TensorStructInfo(dtype="", ndim=3)) + + +def test_max_pool1d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + l = tir.Var("l", "int64") + c16 = tir.Var("c16", "int64") + + x0 = relax.Var("x", R.Tensor((n, c, l), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, l, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool1d( + x0, pool_size=3, strides=3, padding=2, dilation=2 + ), + relax.TensorStructInfo( + ( + n, + c, + tvm.tir.floordiv(l - 1, 3) + 1, + ), + "float32", + ), + ) + _check_inference( + bb, + relax.op.nn.max_pool1d(x1, layout="NCL16c", out_layout="NLC"), + relax.TensorStructInfo((n, l, c * 16), "float32"), + ) + + +def test_max_pool1d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) # For NCL layout + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) # For packed format like NCL16c + s2 = relax.Var("s", relax.ShapeStructInfo()) # Unknown shape + + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, + relax.op.nn.max_pool1d(x1, layout="NCL16c"), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.max_pool1d(x2), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + +def test_max_pool1d_infer_struct_info_ceil_mode(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32), "float32")) # Shape: (N, C, L) + + _check_inference( + bb, + relax.op.nn.max_pool1d(x, pool_size=3, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool1d(x, pool_size=5, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 14), "float32"), + ) + +def test_max_pool1d_infer_struct_info_ceil_mode_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + l = tir.Var("l", "int64") # Length dimension for 1D + x = relax.Var("x", R.Tensor((n, c, l), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool1d( + x, pool_size=3, strides=2, padding=1, dilation=2, ceil_mode=True + ), + relax.TensorStructInfo( + (n, c, tvm.tir.floordiv(l, 2)), + "float32" + ), + ) + + +def test_max_pool1d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32), "int64")) + + _check_inference( + bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float16") + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x1), relax.TensorStructInfo((2, 3, 32), "int8") + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x2), relax.TensorStructInfo((2, 3, 32), "int64") + ) + + +def test_max_pool1d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + max_pool1d = relax.op.nn.max_pool1d(x, pool_size=3, strides=1, padding=1, dilation=1) + + assert max_pool1d.attrs.strides[0].dtype == "int64" + assert max_pool1d.attrs.padding[0].dtype == "int64" + assert max_pool1d.attrs.padding[1].dtype == "int64" + assert max_pool1d.attrs.dilation[0].dtype == "int64" + + +def test_max_pool1d_wrong_pool_size_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + with pytest.raises(TVMError): + relax.op.nn.max_pool1d(x, pool_size=(1, 2)) + with pytest.raises(TVMError): + relax.op.nn.max_pool1d(x, strides=(1, 2)) + with pytest.raises(TVMError): + relax.op.nn.max_pool1d(x, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.max_pool1d(x, dilation=(1, 2)) + + +def test_max_pool1d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool1d(x, layout="OIW")) # Invalid layout + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool1d(x, out_layout="OWI")) # Invalid out_layout + + +def test_max_pool1d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) # 4D input (invalid for 1D pooling) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) # 2D input (also invalid) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool1d(x0)) # Should raise: expected 3D input for NCW layout + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool1d(x1)) # Should raise: expected ndim == 3 + + +def test_max_pool1d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) # Not a tensor + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28), "float32"))) # Function, not tensor + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool1d(x0)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool1d(x1)) + + def test_max_pool2d_infer_struct_info(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") @@ -265,6 +469,243 @@ def test_max_pool2d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.max_pool2d(x1)) +def test_max_pool3d_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + x0 = relax.Var("x", R.Tensor((2, 3, 16, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 16, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=5)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=5)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 16, 32, 32, 16), "float32")) + x7 = relax.Var("x", R.Tensor((2, 3, 16, 32, 32), "float32", vdev0)) + + _check_inference( + bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo((2, 3, 16, 32, 32), "float32") + ) + _check_inference( + bb, relax.op.nn.max_pool3d(x7), relax.TensorStructInfo((2, 3, 16, 32, 32), "float32", vdev0) + ) + _check_inference( + bb, + relax.op.nn.max_pool3d(x0, pool_size=3), + relax.TensorStructInfo((2, 3, 14, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool3d(x0, pool_size=(3, 5, 3)), + relax.TensorStructInfo((2, 3, 14, 28, 30), "float32"), + ) + _check_inference( + bb, relax.op.nn.max_pool3d(x0, padding=1), relax.TensorStructInfo((2, 3, 18, 34, 34), "float32") + ) + _check_inference( + bb, + relax.op.nn.max_pool3d(x0, padding=[1, 2, 3]), + relax.TensorStructInfo((2, 3, 18, 36, 38), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool3d(x0, strides=2), + relax.TensorStructInfo((2, 3, 8, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool3d(x0, dilation=2), + relax.TensorStructInfo((2, 3, 16, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool3d(x1, layout="NDHWC"), + relax.TensorStructInfo((2, 16, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool3d(x0, out_layout="NDHWC"), + relax.TensorStructInfo((2, 16, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool3d(x6, layout="NCDHW16c", out_layout="NDHWC16c"), + relax.TensorStructInfo((2, 32, 32, 32, 4, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.max_pool3d(x2), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.nn.max_pool3d(x3), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.nn.max_pool3d(x4), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.nn.max_pool3d(x5), relax.TensorStructInfo(dtype="", ndim=5)) + + +def test_max_pool3d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + id = tir.Var("id", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, id, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, id, ih, iw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool3d( + x0, pool_size=(3, 3, 3), strides=(3, 3, 3), padding=(2, 2, 2), dilation=(2, 2, 2) + ), + relax.TensorStructInfo( + ( + n, + c, + tvm.tir.floordiv(id - 1, 3) + 1, + tvm.tir.floordiv(ih - 1, 3) + 1, + tvm.tir.floordiv(iw - 1, 3) + 1, + ), + "float32", + ), + ) + + _check_inference( + bb, + relax.op.nn.max_pool3d(x1, layout="NCDHW16c", out_layout="NDHWC"), + relax.TensorStructInfo((n, id, ih, iw, c * 16), "float32"), + ) + + +def test_max_pool3d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) # NDHWC or NCDHW shape + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) # e.g., NCDHW16c + s2 = relax.Var("s", relax.ShapeStructInfo()) # unknown rank + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, + relax.op.nn.max_pool3d(x1, layout="NCDHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=6), + ) + _check_inference( + bb, + relax.op.nn.max_pool3d(x2), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + + +def test_max_pool3d_infer_struct_info_ceil_mode(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool3d(x, pool_size=3, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 16, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool3d(x, pool_size=(5, 3, 3), strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 14, 16, 16), "float32"), + ) + + +def test_max_pool3d_infer_struct_info_ceil_mode_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + id_ = tir.Var("id", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x = relax.Var("x", R.Tensor((n, c, id_, ih, iw), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool3d( + x, pool_size=(3, 3, 3), strides=(2, 2, 2), padding=(1, 1, 1), dilation=(2, 2, 2), ceil_mode=True + ), + relax.TensorStructInfo((n, c, tvm.tir.floordiv(id_, 2), tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), "float32"), + ) + + +def test_max_pool3d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int64")) + _check_inference( + bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo((2, 3, 32, 32, 32), "float16") + ) + _check_inference(bb, relax.op.nn.max_pool3d(x1), relax.TensorStructInfo((2, 3, 32, 32, 32), "int8")) + _check_inference( + bb, relax.op.nn.max_pool3d(x2), relax.TensorStructInfo((2, 3, 32, 32, 32), "int64") + ) + + +def test_max_pool3d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) + max_pool3d = relax.op.nn.max_pool3d(x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) + + assert max_pool3d.attrs.strides[0].dtype == "int64" + assert max_pool3d.attrs.strides[1].dtype == "int64" + assert max_pool3d.attrs.strides[2].dtype == "int64" + assert max_pool3d.attrs.padding[0].dtype == "int64" + assert max_pool3d.attrs.padding[1].dtype == "int64" + assert max_pool3d.attrs.padding[2].dtype == "int64" + assert max_pool3d.attrs.padding[3].dtype == "int64" + assert max_pool3d.attrs.padding[4].dtype == "int64" + assert max_pool3d.attrs.dilation[0].dtype == "int64" + assert max_pool3d.attrs.dilation[1].dtype == "int64" + assert max_pool3d.attrs.dilation[2].dtype == "int64" + + +def test_max_pool3d_wrong_pool_size_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) + with pytest.raises(TVMError): + relax.op.nn.max_pool3d(x, pool_size=(1, 2, 3, 4)) + with pytest.raises(TVMError): + relax.op.nn.max_pool3d(x, strides=(1, 2, 3, 4)) + with pytest.raises(TVMError): + relax.op.nn.max_pool3d(x, padding=(1, 2, 3, 4)) + with pytest.raises(TVMError): + relax.op.nn.max_pool3d(x, dilation=(1, 2, 3, 4)) + + +def test_max_pool3d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool3d(x, layout="OIHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool3d(x, out_layout="OHWI")) + + +def test_max_pool3d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool3d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool3d(x1)) + + +def test_max_pool3d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool3d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool3d(x1)) + + def test_avg_pool2d_infer_struct_info(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") From 190bc8203c89209f9aaffa47f84375a475e129ad Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Mon, 5 May 2025 12:28:03 +0000 Subject: [PATCH 2/6] update the layout used in max pool 1d op --- tests/python/relax/test_op_nn_pooling.py | 48 ++++++++++++------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py index 6423e2adae1f..5eee1a1b5d0a 100644 --- a/tests/python/relax/test_op_nn_pooling.py +++ b/tests/python/relax/test_op_nn_pooling.py @@ -81,11 +81,11 @@ def test_max_pool1d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() n = tir.Var("n", "int64") c = tir.Var("c", "int64") - l = tir.Var("l", "int64") + w = tir.Var("w", "int64") c16 = tir.Var("c16", "int64") - x0 = relax.Var("x", R.Tensor((n, c, l), "float32")) - x1 = relax.Var("x", R.Tensor((n, c, l, c16), "float32")) + x0 = relax.Var("x", R.Tensor((n, c, w), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, w, c16), "float32")) _check_inference( bb, @@ -96,23 +96,23 @@ def test_max_pool1d_infer_struct_info_shape_symbolic(): ( n, c, - tvm.tir.floordiv(l - 1, 3) + 1, + tvm.tir.floordiv(w - 1, 3) + 1, ), "float32", ), ) _check_inference( bb, - relax.op.nn.max_pool1d(x1, layout="NCL16c", out_layout="NLC"), - relax.TensorStructInfo((n, l, c * 16), "float32"), + relax.op.nn.max_pool1d(x1, layout="NCW16c", out_layout="NWC"), + relax.TensorStructInfo((n, w, c * 16), "float32"), ) def test_max_pool1d_infer_struct_info_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) # For NCL layout - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) # For packed format like NCL16c - s2 = relax.Var("s", relax.ShapeStructInfo()) # Unknown shape + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s", relax.ShapeStructInfo()) x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) @@ -123,7 +123,7 @@ def test_max_pool1d_infer_struct_info_shape_var(): ) _check_inference( bb, - relax.op.nn.max_pool1d(x1, layout="NCL16c"), + relax.op.nn.max_pool1d(x1, layout="NCW16c"), relax.TensorStructInfo(dtype="float32", ndim=4), ) _check_inference( @@ -135,7 +135,7 @@ def test_max_pool1d_infer_struct_info_shape_var(): def test_max_pool1d_infer_struct_info_ceil_mode(): bb = relax.BlockBuilder() - x = relax.Var("x", R.Tensor((2, 3, 32), "float32")) # Shape: (N, C, L) + x = relax.Var("x", R.Tensor((2, 3, 32), "float32")) _check_inference( bb, @@ -152,8 +152,8 @@ def test_max_pool1d_infer_struct_info_ceil_mode_symbolic(): bb = relax.BlockBuilder() n = tir.Var("n", "int64") c = tir.Var("c", "int64") - l = tir.Var("l", "int64") # Length dimension for 1D - x = relax.Var("x", R.Tensor((n, c, l), "float32")) + w = tir.Var("w", "int64") + x = relax.Var("x", R.Tensor((n, c, w), "float32")) _check_inference( bb, @@ -161,7 +161,7 @@ def test_max_pool1d_infer_struct_info_ceil_mode_symbolic(): x, pool_size=3, strides=2, padding=1, dilation=2, ceil_mode=True ), relax.TensorStructInfo( - (n, c, tvm.tir.floordiv(l, 2)), + (n, c, tvm.tir.floordiv(w, 2)), "float32" ), ) @@ -210,26 +210,26 @@ def test_max_pool1d_infer_struct_info_wrong_layout_string(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) with pytest.raises(TVMError): - bb.normalize(relax.op.nn.max_pool1d(x, layout="OIW")) # Invalid layout + bb.normalize(relax.op.nn.max_pool1d(x, layout="OIW")) with pytest.raises(TVMError): - bb.normalize(relax.op.nn.max_pool1d(x, out_layout="OWI")) # Invalid out_layout + bb.normalize(relax.op.nn.max_pool1d(x, out_layout="OWI")) def test_max_pool1d_wrong_input_ndim(): bb = relax.BlockBuilder() - x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) # 4D input (invalid for 1D pooling) - x1 = relax.Var("x", R.Tensor("float32", ndim=2)) # 2D input (also invalid) + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=5)) with pytest.raises(TVMError): - bb.normalize(relax.op.nn.max_pool1d(x0)) # Should raise: expected 3D input for NCW layout + bb.normalize(relax.op.nn.max_pool1d(x0)) with pytest.raises(TVMError): - bb.normalize(relax.op.nn.max_pool1d(x1)) # Should raise: expected ndim == 3 + bb.normalize(relax.op.nn.max_pool1d(x1)) def test_max_pool1d_infer_struct_info_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) # Not a tensor + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28), "float32"))) # Function, not tensor with pytest.raises(TVMError): @@ -577,9 +577,9 @@ def test_max_pool3d_infer_struct_info_shape_symbolic(): def test_max_pool3d_infer_struct_info_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) # NDHWC or NCDHW shape - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) # e.g., NCDHW16c - s2 = relax.Var("s", relax.ShapeStructInfo()) # unknown rank + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + s2 = relax.Var("s", relax.ShapeStructInfo()) x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) From 49d2918c1d5f95b704589a87c0a6e3f9ca30241b Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 6 May 2025 04:17:52 +0000 Subject: [PATCH 3/6] add mappings into fx translator and fix lint issues --- .../tvm/relax/frontend/torch/fx_translator.py | 4 + src/relax/op/nn/pooling.cc | 2 +- tests/python/relax/test_frontend_from_fx.py | 18 +++-- tests/python/relax/test_op_nn_pooling.py | 77 ++++++++++--------- 4 files changed, 59 insertions(+), 42 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5f65f86a4303..4df33a465795 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -661,7 +661,9 @@ def create_convert_map( nn.GroupNorm: self._group_norm_module, nn.LayerNorm: self._layer_norm_module, nn.Linear: self._linear_module, + nn.MaxPool1d: self._max_pool1d_module, nn.MaxPool2d: self._max_pool2d_module, + nn.MaxPool3d: self._max_pool3d_module, nn.modules.sparse.Embedding: self._embedding_module, nn.PixelShuffle: self._pixel_shuffle_module, # tensor manipulation @@ -772,7 +774,9 @@ def create_convert_map( "interpolate": self._interpolate, "layer_norm": self._layer_norm, "linear": self._linear, + "max_pool2d": self._max_pool1d, "max_pool2d": self._max_pool2d, + "max_pool2d": self._max_pool3d, "scaled_dot_product_attention": self._scaled_dot_product_attention, "stochastic_depth": lambda node: self.env[node.args[0]], "unbind": self._unbind, diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 565e6a00c60d..391edda9ef38 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -95,7 +95,7 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1; if (attrs->ceil_mode) { - numerator_w += attrs->strides[1] - 1; + numerator_w += attrs->strides[0] - 1; } out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1); diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 5b069e417808..bd7e692d6f7d 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1005,7 +1005,9 @@ def forward(self, input): @tvm.script.ir_module class expected1: @R.function - def main(input_1: R.Tensor((1, 3, 10), dtype="float32")) -> R.Tensor((1, 3, 5), dtype="float32"): + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tensor((1, 3, 5), dtype="float32"): with R.dataflow(): lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.max_pool1d( input_1, @@ -1014,7 +1016,7 @@ def main(input_1: R.Tensor((1, 3, 10), dtype="float32")) -> R.Tensor((1, 3, 5), dilation=[1], padding=[0, 0], layout="NCW", - out_layout="NCW" + out_layout="NCW", ) gv: R.Tensor((1, 3, 5), dtype="float32") = lv R.output(gv) @@ -1031,7 +1033,9 @@ def forward(self, input): @tvm.script.ir_module class expected2: @R.function - def main(input_1: R.Tensor((1, 3, 10), dtype="float32")) -> R.Tensor((1, 3, 10), dtype="float32"): + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10), dtype="float32"): with R.dataflow(): lv: R.Tensor((1, 3, 10), dtype="float32") = R.nn.max_pool1d( input_1, @@ -1040,7 +1044,7 @@ def main(input_1: R.Tensor((1, 3, 10), dtype="float32")) -> R.Tensor((1, 3, 10), dilation=[1], padding=[1, 1], layout="NCW", - out_layout="NCW" + out_layout="NCW", ) gv: R.Tensor((1, 3, 10), dtype="float32") = lv R.output(gv) @@ -1057,7 +1061,9 @@ def forward(self, input): @tvm.script.ir_module class expected3: @R.function - def main(input_1: R.Tensor((1, 3, 10), dtype="float32")) -> R.Tensor((1, 3, 3), dtype="float32"): # Corrected here + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tensor((1, 3, 3), dtype="float32"): # Corrected here with R.dataflow(): lv: R.Tensor((1, 3, 3), dtype="float32") = R.nn.max_pool1d( input_1, @@ -1066,7 +1072,7 @@ def main(input_1: R.Tensor((1, 3, 10), dtype="float32")) -> R.Tensor((1, 3, 3), dilation=[2], padding=[0, 0], layout="NCW", - out_layout="NCW" + out_layout="NCW", ) gv: R.Tensor((1, 3, 3), dtype="float32") = lv R.output(gv) diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py index 5eee1a1b5d0a..da51e67de31c 100644 --- a/tests/python/relax/test_op_nn_pooling.py +++ b/tests/python/relax/test_op_nn_pooling.py @@ -49,9 +49,7 @@ def test_max_pool1d_infer_struct_info(): x4 = relax.Var("x", R.Tensor()) x5 = relax.Var("x", R.Tensor((2, 3, 32), "float32", vdev0)) - _check_inference( - bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float32") - ) + _check_inference(bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float32")) _check_inference( bb, relax.op.nn.max_pool1d(x5), relax.TensorStructInfo((2, 3, 32), "float32", vdev0) ) @@ -62,18 +60,23 @@ def test_max_pool1d_infer_struct_info(): bb, relax.op.nn.max_pool1d(x0, strides=2), relax.TensorStructInfo((2, 3, 16), "float32") ) _check_inference( - bb, relax.op.nn.max_pool1d(x0, padding=1), relax.TensorStructInfo((2, 3, 32), "float32") + bb, relax.op.nn.max_pool1d(x0, padding=1), relax.TensorStructInfo((2, 3, 34), "float32") ) _check_inference( bb, relax.op.nn.max_pool1d(x0, dilation=2), relax.TensorStructInfo((2, 3, 32), "float32") ) _check_inference( - bb, relax.op.nn.max_pool1d(x0, layout="NCW", out_layout="NWC"), - relax.TensorStructInfo((2, 32, 3), "float32") + bb, + relax.op.nn.max_pool1d(x0, layout="NCW", out_layout="NWC"), + relax.TensorStructInfo((2, 32, 3), "float32"), + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x1), relax.TensorStructInfo(dtype="float32", ndim=3) ) - _check_inference(bb, relax.op.nn.max_pool1d(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) _check_inference(bb, relax.op.nn.max_pool1d(x2), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.nn.max_pool1d(x3), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference( + bb, relax.op.nn.max_pool1d(x3), relax.TensorStructInfo(dtype="float32", ndim=3) + ) _check_inference(bb, relax.op.nn.max_pool1d(x4), relax.TensorStructInfo(dtype="", ndim=3)) @@ -89,9 +92,7 @@ def test_max_pool1d_infer_struct_info_shape_symbolic(): _check_inference( bb, - relax.op.nn.max_pool1d( - x0, pool_size=3, strides=3, padding=2, dilation=2 - ), + relax.op.nn.max_pool1d(x0, pool_size=3, strides=3, padding=2, dilation=2), relax.TensorStructInfo( ( n, @@ -148,6 +149,7 @@ def test_max_pool1d_infer_struct_info_ceil_mode(): relax.TensorStructInfo((2, 3, 14), "float32"), ) + def test_max_pool1d_infer_struct_info_ceil_mode_symbolic(): bb = relax.BlockBuilder() n = tir.Var("n", "int64") @@ -157,13 +159,8 @@ def test_max_pool1d_infer_struct_info_ceil_mode_symbolic(): _check_inference( bb, - relax.op.nn.max_pool1d( - x, pool_size=3, strides=2, padding=1, dilation=2, ceil_mode=True - ), - relax.TensorStructInfo( - (n, c, tvm.tir.floordiv(w, 2)), - "float32" - ), + relax.op.nn.max_pool1d(x, pool_size=3, strides=2, padding=1, dilation=2, ceil_mode=True), + relax.TensorStructInfo((n, c, tvm.tir.floordiv(w, 2)), "float32"), ) @@ -173,15 +170,9 @@ def test_max_pool1d_infer_struct_info_more_input_dtype(): x1 = relax.Var("x", R.Tensor((2, 3, 32), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 32), "int64")) - _check_inference( - bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float16") - ) - _check_inference( - bb, relax.op.nn.max_pool1d(x1), relax.TensorStructInfo((2, 3, 32), "int8") - ) - _check_inference( - bb, relax.op.nn.max_pool1d(x2), relax.TensorStructInfo((2, 3, 32), "int64") - ) + _check_inference(bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float16")) + _check_inference(bb, relax.op.nn.max_pool1d(x1), relax.TensorStructInfo((2, 3, 32), "int8")) + _check_inference(bb, relax.op.nn.max_pool1d(x2), relax.TensorStructInfo((2, 3, 32), "int64")) def test_max_pool1d_stride_padding_dilation_int64(): @@ -230,7 +221,9 @@ def test_max_pool1d_wrong_input_ndim(): def test_max_pool1d_infer_struct_info_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28), "float32"))) # Function, not tensor + x1 = relax.Var( + "x", relax.FuncStructInfo([], R.Tensor((2, 3, 28), "float32")) + ) with pytest.raises(TVMError): bb.normalize(relax.op.nn.max_pool1d(x0)) @@ -498,7 +491,9 @@ def test_max_pool3d_infer_struct_info(): relax.TensorStructInfo((2, 3, 14, 28, 30), "float32"), ) _check_inference( - bb, relax.op.nn.max_pool3d(x0, padding=1), relax.TensorStructInfo((2, 3, 18, 34, 34), "float32") + bb, + relax.op.nn.max_pool3d(x0, padding=1), + relax.TensorStructInfo((2, 3, 18, 34, 34), "float32"), ) _check_inference( bb, @@ -528,7 +523,7 @@ def test_max_pool3d_infer_struct_info(): _check_inference( bb, relax.op.nn.max_pool3d(x6, layout="NCDHW16c", out_layout="NDHWC16c"), - relax.TensorStructInfo((2, 32, 32, 32, 4, 16), "float32"), + relax.TensorStructInfo((2, 16, 32, 32, 4, 16), "float32"), ) _check_inference( bb, relax.op.nn.max_pool3d(x2), relax.TensorStructInfo(dtype="float32", ndim=5) @@ -611,7 +606,7 @@ def test_max_pool3d_infer_struct_info_ceil_mode(): _check_inference( bb, relax.op.nn.max_pool3d(x, pool_size=(5, 3, 3), strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 14, 16, 16), "float32"), + relax.TensorStructInfo((2, 3, 15, 16, 16), "float32"), ) @@ -627,9 +622,17 @@ def test_max_pool3d_infer_struct_info_ceil_mode_symbolic(): _check_inference( bb, relax.op.nn.max_pool3d( - x, pool_size=(3, 3, 3), strides=(2, 2, 2), padding=(1, 1, 1), dilation=(2, 2, 2), ceil_mode=True + x, + pool_size=(3, 3, 3), + strides=(2, 2, 2), + padding=(1, 1, 1), + dilation=(2, 2, 2), + ceil_mode=True, + ), + relax.TensorStructInfo( + (n, c, tvm.tir.floordiv(id_, 2), tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), + "float32", ), - relax.TensorStructInfo((n, c, tvm.tir.floordiv(id_, 2), tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), "float32"), ) @@ -641,7 +644,9 @@ def test_max_pool3d_infer_struct_info_more_input_dtype(): _check_inference( bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo((2, 3, 32, 32, 32), "float16") ) - _check_inference(bb, relax.op.nn.max_pool3d(x1), relax.TensorStructInfo((2, 3, 32, 32, 32), "int8")) + _check_inference( + bb, relax.op.nn.max_pool3d(x1), relax.TensorStructInfo((2, 3, 32, 32, 32), "int8") + ) _check_inference( bb, relax.op.nn.max_pool3d(x2), relax.TensorStructInfo((2, 3, 32, 32, 32), "int64") ) @@ -649,7 +654,9 @@ def test_max_pool3d_infer_struct_info_more_input_dtype(): def test_max_pool3d_stride_padding_dilation_int64(): x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) - max_pool3d = relax.op.nn.max_pool3d(x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) + max_pool3d = relax.op.nn.max_pool3d( + x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1) + ) assert max_pool3d.attrs.strides[0].dtype == "int64" assert max_pool3d.attrs.strides[1].dtype == "int64" From 7591ea4cdde857c3444e129419b8cd5d44e6a2e1 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 6 May 2025 05:55:05 +0000 Subject: [PATCH 4/6] fix missing incorrect mappings and add module func --- .../tvm/relax/frontend/torch/fx_translator.py | 26 +++++++++++++++++-- tests/python/relax/test_op_nn_pooling.py | 4 +-- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 4df33a465795..9e5d28336d33 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -449,6 +449,17 @@ def _linear_module(self, node: fx.Node) -> relax.Var: bias = self.params.get(module.bias, None) return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + def _max_pool1d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + ceil_mode = module.ceil_mode + + return self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _max_pool2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -460,6 +471,17 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _max_pool3d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + ceil_mode = module.ceil_mode + + return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _pixel_shuffle_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -774,9 +796,9 @@ def create_convert_map( "interpolate": self._interpolate, "layer_norm": self._layer_norm, "linear": self._linear, - "max_pool2d": self._max_pool1d, + "max_pool1d": self._max_pool1d, "max_pool2d": self._max_pool2d, - "max_pool2d": self._max_pool3d, + "max_pool3d": self._max_pool3d, "scaled_dot_product_attention": self._scaled_dot_product_attention, "stochastic_depth": lambda node: self.env[node.args[0]], "unbind": self._unbind, diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py index da51e67de31c..a75f9977a898 100644 --- a/tests/python/relax/test_op_nn_pooling.py +++ b/tests/python/relax/test_op_nn_pooling.py @@ -221,9 +221,7 @@ def test_max_pool1d_wrong_input_ndim(): def test_max_pool1d_infer_struct_info_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) - x1 = relax.Var( - "x", relax.FuncStructInfo([], R.Tensor((2, 3, 28), "float32")) - ) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28), "float32"))) with pytest.raises(TVMError): bb.normalize(relax.op.nn.max_pool1d(x0)) From 9e7acb560de8be147f9c76b779ec1ce3a558fb3c Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 6 May 2025 06:40:17 +0000 Subject: [PATCH 5/6] update output tensor struct info for maxpool1d --- tests/python/relax/test_op_nn_pooling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py index a75f9977a898..0d58af1cbec3 100644 --- a/tests/python/relax/test_op_nn_pooling.py +++ b/tests/python/relax/test_op_nn_pooling.py @@ -146,7 +146,7 @@ def test_max_pool1d_infer_struct_info_ceil_mode(): _check_inference( bb, relax.op.nn.max_pool1d(x, pool_size=5, strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 14), "float32"), + relax.TensorStructInfo((2, 3, 15), "float32"), ) From 4ab72c06e96ba34569e4e3a0ff5511282164c880 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Wed, 7 May 2025 05:39:22 +0000 Subject: [PATCH 6/6] add docs for handling edge cases --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 2808805e1fae..dbcb5467f14b 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -871,6 +871,7 @@ def _max_pool1d_impl( dilation: Optional[int] = 1, ceil_mode: Optional[bool] = False, ) -> relax.Var: + # Expand to 3D by adding batch dim if input is 2D x_ndim = x.struct_info.ndim if x_ndim == 2: x = relax.op.expand_dims(x, axis=0) @@ -889,6 +890,7 @@ def _max_pool1d_impl( ) ) + # Remove added batch dim from result if x_ndim == 2: result = relax.op.squeeze(result, axis=[0]) return result @@ -913,6 +915,7 @@ def _max_pool2d_impl( dilation: Optional[int] = 1, ceil_mode: Optional[bool] = False, ) -> relax.Var: + # Expand to 4D by adding batch dim if input is 3D x_ndim = x.struct_info.ndim if x_ndim == 3: x = relax.op.expand_dims(x, axis=0) @@ -931,6 +934,7 @@ def _max_pool2d_impl( ) ) + # Remove added batch dim from result if x_ndim == 3: result = relax.op.squeeze(result, axis=[0]) return result @@ -955,6 +959,7 @@ def _max_pool3d_impl( dilation: Optional[int] = 1, ceil_mode: Optional[bool] = False, ) -> relax.Var: + # Expand to 5D by adding batch dim if input is 4D x_ndim = x.struct_info.ndim if x_ndim == 4: x = relax.op.expand_dims(x, axis=0) @@ -973,6 +978,7 @@ def _max_pool3d_impl( ) ) + # Remove added batch dim from result if x_ndim == 4: result = relax.op.squeeze(result, axis=[0]) return result