From 239aab28bbc46300a009ee9e999039a3f93f586f Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Fri, 12 Dec 2025 16:06:30 +0700 Subject: [PATCH 1/5] [Relax][Torch] Fix sum op when dim is omitted and when keepdim is given - Without dim: args[1] was an empty list and was still passed into relax.op.sum, which gave an incorrect result. - Keepdim: previously the keepdim value was read only from node.kwargs and was not passed into relax.op.sum when dim was given. Now keepdim is also read from args[2] and is always passed into relax.op.sum. --- .../torch/base_fx_graph_translator.py | 10 +++--- .../test_frontend_from_exported_program.py | 36 +++++++++++++++++++ 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 47eb66621008..f7d54a6216a7 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1628,10 +1628,12 @@ def _std(self, node: fx.Node) -> relax.Var: def _sum(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) - keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False - if len(args) == 1: - return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) - return self.block_builder.emit(relax.op.sum(args[0], args[1])) + x = args[0] + dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + if isinstance(dim, (list, tuple)) and len(dim) == 0: + dim = None + keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim)) def _var(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 01e16e7564ac..cd16d1909f96 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4944,6 +4944,14 @@ def 
test_sum(): class Sum(Module): def forward(self, x): return torch.sum(x, (2, 1)) + + class SumKeepDim(Module): + def forward(self, x): + return torch.sum(x, (2, 1), keepdim=True) + + class SumWithoutDim(Module): + def forward(self, x): + return torch.sum(x) @tvm.script.ir_module class expected1: @@ -4958,8 +4966,36 @@ def main( R.output(gv) return gv + @tvm.script.ir_module + class expected2: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 1, 1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=True) + gv: R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected3: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.sum(inp_0, axis=None, keepdims=False) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + R.output(gv) + return gv + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) verify_model(Sum(), example_args, {}, expected1) + verify_model(SumKeepDim(), example_args, {}, expected2) + verify_model(SumWithoutDim(), example_args, {}, expected3) def test_argmax_argmin(): From 0019f82f5aaa0c276b90c8a2182596b23bd4c950 Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Fri, 12 Dec 2025 20:40:45 +0700 Subject: [PATCH 2/5] Remove and Rerun --- tests/python/relax/test_frontend_from_exported_program.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index cd16d1909f96..f1bf4828719a 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4972,7 +4972,6 @@ class expected2: def main( inp_0: R.Tensor((1, 2, 3, 4), 
dtype="float32") ) -> R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")): - # block 0 with R.dataflow(): lv: R.Tensor((1, 1, 1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=True) gv: R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")) = (lv,) @@ -4985,7 +4984,6 @@ class expected3: def main( inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((), dtype="float32")): - # block 0 with R.dataflow(): lv: R.Tensor((), dtype="float32") = R.sum(inp_0, axis=None, keepdims=False) gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) From 5432f24ea011c6e7675d219a00148ea36f56faec Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Sat, 13 Dec 2025 10:23:30 +0700 Subject: [PATCH 3/5] Fix lint error in test_frontend_from_exported_program.py --- tests/python/relax/test_frontend_from_exported_program.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index f1bf4828719a..2a58b50393f0 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4944,11 +4944,11 @@ def test_sum(): class Sum(Module): def forward(self, x): return torch.sum(x, (2, 1)) - + class SumKeepDim(Module): def forward(self, x): return torch.sum(x, (2, 1), keepdim=True) - + class SumWithoutDim(Module): def forward(self, x): return torch.sum(x) @@ -4966,7 +4966,7 @@ def main( R.output(gv) return gv - @tvm.script.ir_module + @tvm.script.ir_module class expected2: @R.function def main( From 51ccf5838998ab66f07c2e0404b1414e258b3d13 Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Sat, 13 Dec 2025 10:51:58 +0700 Subject: [PATCH 4/5] Fix test_cross_entropy with sum model fixed --- .../relax/test_frontend_from_exported_program.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py index 2a58b50393f0..7328bf081439 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -7874,7 +7874,7 @@ def forward(self, x): @tvm.script.ir_module class Expected1: @R.function - def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")): + def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")): with R.dataflow(): lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, dtype="float32") lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, axis=1) @@ -7897,11 +7897,11 @@ def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype=" lv12: R.Tensor((4,), dtype="bool") = R.not_equal( R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") ) - lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], keepdims=False) - lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, dtype="float32") - lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], keepdims=False) - lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14) - gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,) + lv13: R.Tensor((), dtype="bool") = R.sum(lv12, axis=None, keepdims=False) + lv14: R.Tensor((), dtype="float32") = R.astype(lv13, dtype="float32") + lv15: R.Tensor((), dtype="float32") = R.sum(lv11, axis=None, keepdims=False) + lv16: R.Tensor((), dtype="float32") = R.divide(lv15, lv14) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv16,) R.output(gv) return gv From 91e01650b6652f65fcd043ac201b2bb85d7bcd68 Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Sat, 13 Dec 2025 11:29:20 +0700 Subject: [PATCH 5/5] Fix line length lint error --- tests/python/relax/test_frontend_from_exported_program.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py index 7328bf081439..4a84b50cc9d9 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4973,7 +4973,9 @@ def main( inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 1, 1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=True) + lv: R.Tensor((1, 1, 1, 4), dtype="float32") = R.sum( + inp_0, axis=[2, 1], keepdims=True + ) gv: R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")) = (lv,) R.output(gv) return gv