From c97d03b9b6570d19fc410b908807ee303416f15a Mon Sep 17 00:00:00 2001
From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com>
Date: Thu, 27 Nov 2025 16:44:47 +0800
Subject: [PATCH] Fix batch normalization training mode correctness

---
 .../torch/exported_program_translator.py      | 34 ++++++---
 .../test_frontend_from_exported_program.py    | 76 +++++++++++++++++++
 2 files changed, 99 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 7af8774ee3a1..1f60d02a79ea 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -116,7 +116,7 @@ def _rsqrt(self, node: fx.Node) -> relax.Var:
 
     ########## Neural Network ##########
 
-    def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
+    def _batch_norm(self, node: fx.Node, training: bool, return_tuple: bool = False) -> relax.Var:
         import numpy as np
 
         x = self.env[node.args[0]]
@@ -149,7 +149,7 @@ def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
         if track_running_stats:
             training = True
 
-        return self.block_builder.emit(
+        bn_result = self.block_builder.emit(
             relax.op.nn.batch_norm(
                 data=x,
                 gamma=weight,
@@ -160,21 +160,33 @@ def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
                 epsilon=eps,
                 momentum=momentum,
                 training=training,
-            )[0]
+            )
         )
 
+        if return_tuple:
+            return bn_result
+        else:
+            # Return only the output tensor (for backward compatibility)
+            return self.block_builder.emit(bn_result[0])
+
     def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var:
         # This method is called for batch_norm in training mode
-        # TODO does not have correctness!
-        # TODO we need to store the running mean and variance returned by the
-        # previous call to batch_norm and pass it again
-        training = True
-        return self._batch_norm(node, training)
+        bn_tuple = self._batch_norm(node, training=True, return_tuple=True)
+
+        x = self.env[node.args[0]]
+        channel = int(self.shape_of(x)[1])
+        dtype = x.struct_info.dtype
+
+        output = self.block_builder.emit(bn_tuple[0])
+        new_running_mean = self.block_builder.emit(bn_tuple[1])
+        reserve = self.block_builder.emit(relax.op.zeros(relax.ShapeExpr([channel]), dtype))
+
+        return self.block_builder.emit(
+            relax.Tuple([output, new_running_mean, reserve, reserve, reserve])
+        )
 
     def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
-        # This method is called for batch_norm in eval mode
-        training = False
-        return self._batch_norm(node, training)
+        return self._batch_norm(node, training=False, return_tuple=False)
 
     def _batch_norm_legit_no_stats(self, node: fx.Node) -> relax.Var:
         import numpy as np
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 98c6c6d01485..7e4c6d328a74 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1788,6 +1788,82 @@ def main(
     }
     verify_model(model_2, example_args, binding_2, expected2)
 
+    class BatchNorm2dTraining(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = torch.nn.BatchNorm2d(3, track_running_stats=True)
+
+        def forward(self, input):
+            return self.bn(input)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((2, 3, 4, 4), dtype="float32"),
+            w1: R.Tensor((3,), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+            w3: R.Tensor((3,), dtype="float32"),
+            w4: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tuple(
+            R.Tensor((3,), dtype="float32"),
+            R.Tensor((3,), dtype="float32"),
+            R.Tensor((), dtype="int64"),
+            R.Tensor((2, 3, 4, 4), dtype="float32"),
+        ):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
+                lv1: R.Tuple(
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = R.nn.batch_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    w3,
+                    w4,
+                    axis=1,
+                    epsilon=0.1,
+                    center=True,
+                    scale=True,
+                    momentum=1.0,
+                    training=True,
+                )
+                lv2: R.Tensor((2, 3, 4, 4), dtype="float32") = lv1[0]
+                lv3: R.Tensor((3,), dtype="float32") = lv1[1]
+                lv4: R.Tensor((3,), dtype="float32") = R.zeros(R.shape([3]), dtype="float32")
+                lv5: R.Tuple(
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = (lv2, lv3, lv4, lv4, lv4)
+                lv6: R.Tensor((2, 3, 4, 4), dtype="float32") = lv5[0]
+                lv7: R.Tensor((3,), dtype="float32") = lv5[3]
+                lv8: R.Tensor((3,), dtype="float32") = lv5[4]
+                gv: R.Tuple(
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((), dtype="int64"),
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                ) = (lv7, lv8, lv, lv6)
+                R.output(gv)
+            return gv
+
+    example_args_train = (torch.randn(2, 3, 4, 4, dtype=torch.float32),)
+
+    model_3 = BatchNorm2dTraining()
+    model_3.train()  # Set to training mode
+    binding_3 = {
+        "w1": model_3.bn.weight.detach().numpy(),
+        "w2": model_3.bn.bias.detach().numpy(),
+        "w3": model_3.bn.running_mean.detach().numpy(),
+        "w4": model_3.bn.running_var.detach().numpy(),
+    }
+    verify_model(model_3, example_args_train, binding_3, expected3)
+
 
 def test_adaptive_avgpool1d():
     class AdaptiveAvgPool1d0(torch.nn.Module):
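
Usage sketch (not part of the patch itself): the snippet below shows how the
fixed training-mode path would be exercised end to end. It assumes the public
`from_exported_program` entry point in `tvm.relax.frontend.torch` and
`torch.export.export`; the `Net` module name, tensor shapes, and the final
`mod.show()` call are illustrative only. Exporting a module that holds a
BatchNorm2d while it is in train() mode is what makes torch.export emit the
training-mode batch_norm variant, which this patch now lowers to a
five-element tuple instead of a single tensor.

    import torch
    from tvm.relax.frontend.torch import from_exported_program

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Mirrors the new test: 3 channels, running statistics tracked.
            self.bn = torch.nn.BatchNorm2d(3, track_running_stats=True)

        def forward(self, x):
            return self.bn(x)

    model = Net()
    model.train()  # training mode selects the tuple-returning batch_norm op
    exported = torch.export.export(model, (torch.randn(2, 3, 4, 4),))
    mod = from_exported_program(exported)
    mod.show()  # inspect the resulting Relax IRModule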