diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 3d6a632fb20f..94df0282c870 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1371,6 +1371,7 @@ def create_convert_map( "any.dim": self._any, "any.dims": self._any, "mean.dim": self._mean, + "mean.default": self._mean, "prod.default": self._prod, "std.correction": self._std, "sum.default": self._sum, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 01e16e7564ac..b7a36f947019 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4911,6 +4911,10 @@ class MeanKeepDim(Module): def forward(self, input: torch.Tensor): return input.mean(-1, keepdim=True) + class MeanWithoutDim(Module): + def forward(self, input: torch.Tensor): + return input.mean() + @I.ir_module class Expected1: @R.function @@ -4935,9 +4939,22 @@ def main( R.output(gv) return gv + @I.ir_module + class Expected3: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.mean(inp_0, axis=None, keepdims=False) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + R.output(gv) + return gv + example_args = (torch.randn(256, 256, dtype=torch.float32),) verify_model(Mean(), example_args, {}, Expected1) verify_model(MeanKeepDim(), example_args, {}, Expected2) + verify_model(MeanWithoutDim(), example_args, {}, Expected3) def test_sum():