diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index ac79024acfb9..cd124171b10e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -891,6 +891,49 @@ def _zeros(self, node: fx.Node) -> relax.Var:
         )
         return self.block_builder.emit(relax.op.zeros(size, dtype))
 
+    def _sparse_mm(self, node: fx.Node) -> relax.Var:
+        """Handle sparse matrix multiplication by converting the sparse tensor to dense."""
+        args = self.retrieve_args(node)
+        sparse_input = args[0]
+        dense_input = args[1]
+        # No conversion is needed at this point: sparse_input has already been
+        # densified in _convert_pytorch_tensor_to_tvm, so a regular matrix
+        # multiplication suffices here.
+        return self.block_builder.emit(
+            relax.op.linear_algebra.matmul(sparse_input, dense_input, out_dtype="float32")
+        )
+
+    def _sparse_addmm(self, node: fx.Node) -> relax.Var:
+        """Handle sparse addmm (beta * input + alpha * sparse_mm(mat1, mat2))."""
+        args = self.retrieve_args(node)
+        input_tensor = args[0]  # dense input, scaled by beta below
+        sparse_mat1 = args[1]  # sparse matrix
+        dense_mat2 = args[2]  # dense matrix
+        alpha = node.kwargs.get("alpha", 1.0)
+        beta = node.kwargs.get("beta", 1.0)
+
+        # No conversion is needed at this point: sparse_mat1 has already been
+        # densified in _convert_pytorch_tensor_to_tvm.
+        # Compute alpha * matmul(mat1, mat2)
+        matmul_result = self.block_builder.emit(
+            relax.op.linear_algebra.matmul(sparse_mat1, dense_mat2, out_dtype="float32")
+        )
+
+        if alpha != 1.0:
+            alpha_const = relax.const(alpha, matmul_result.struct_info.dtype)
+            matmul_result = self.block_builder.emit(relax.op.multiply(matmul_result, alpha_const))
+
+        # Compute beta * input + alpha * matmul_result
+        if beta != 0.0:
+            if beta != 1.0:
+                beta_const = relax.const(beta, input_tensor.struct_info.dtype)
+                input_scaled = self.block_builder.emit(relax.op.multiply(input_tensor, beta_const))
+            else:
+                input_scaled = input_tensor
+            return self.block_builder.emit(relax.op.add(input_scaled, matmul_result))
+        else:
+            return matmul_result
+
     def _grid_sampler_2d(self, node: fx.Node) -> relax.Var:
         """Convert torch.nn.functional.grid_sample to relax.op.image.grid_sample."""
         args = self.retrieve_args(node)
@@ -1184,6 +1227,8 @@ def create_convert_map(
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
             "adaptive_avg_pool3d.default": self._adaptive_avg_pool3d,
             "addmm.default": self._addmm,
+            "_sparse_mm.default": self._sparse_mm,
+            "_sparse_addmm.default": self._sparse_addmm,
             "avg_pool1d.default": self._avg_pool1d,
             "avg_pool2d.default": self._avg_pool2d,
             "avg_pool3d.default": self._avg_pool3d,
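For reference, `_sparse_addmm` above implements the documented semantics of `torch.sparse.addmm`, namely `beta * input + alpha * (mat1 @ mat2)`, on inputs the importer has already densified. A minimal dense-side sketch of the same arithmetic (illustrative only; `addmm_reference` is a hypothetical helper, not part of this patch):

    import torch

    def addmm_reference(inp, mat1, mat2, *, beta=1.0, alpha=1.0):
        # Densify mat1 if it is sparse, mirroring what the translator
        # relies on having happened upstream of this handler.
        mat1 = mat1.to_dense() if mat1.is_sparse else mat1
        out = alpha * (mat1 @ mat2)
        # beta == 0 skips the addition, matching the branch in _sparse_addmm.
        return beta * inp + out if beta != 0.0 else out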
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 78a8a09a3cf4..fe3ff28aea0f 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -32,14 +32,29 @@
 
 
 def verify_model(
-    torch_model, example_args, binding, expected, dynamic_shapes=None, run_ep_decomposition=True
+    torch_model,
+    example_args,
+    binding,
+    expected,
+    dynamic_shapes=None,
+    run_ep_decomposition=True,
+    keep_params_as_input=False,
+    unwrap_unit_return_tuple=False,
+    no_bind_return_tuple=False,
+    map_free_vars=False,
 ):
     exported_program = export(torch_model, args=example_args, dynamic_shapes=dynamic_shapes)
-    mod = from_exported_program(exported_program, run_ep_decomposition=run_ep_decomposition)
+    mod = from_exported_program(
+        exported_program,
+        run_ep_decomposition=run_ep_decomposition,
+        keep_params_as_input=keep_params_as_input,
+        unwrap_unit_return_tuple=unwrap_unit_return_tuple,
+        no_bind_return_tuple=no_bind_return_tuple,
+    )
 
     binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()}
     expected = relax.transform.BindParams("main", binding)(expected)
-    tvm.ir.assert_structural_equal(mod, expected)
+    tvm.ir.assert_structural_equal(mod, expected, map_free_vars=map_free_vars)
 
 
 operator_basic_unary = [
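The new keyword arguments let individual tests route importer and comparison options through this shared helper instead of calling `export` and `from_exported_program` by hand; `map_free_vars` is forwarded to `tvm.ir.assert_structural_equal` so expected modules that use free symbolic shape variables (e.g. `s0`) still compare equal. An illustrative call, mirroring the rewritten tests further down:

    verify_model(
        Identity(), example_args, {}, Expected, unwrap_unit_return_tuple=True
    )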
@@ -1788,6 +1803,82 @@ def main(
             }
         verify_model(model_2, example_args, binding_2, expected2)
 
+    class BatchNorm2dTraining(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = torch.nn.BatchNorm2d(3, track_running_stats=True)
+
+        def forward(self, input):
+            return self.bn(input)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((2, 3, 4, 4), dtype="float32"),
+            w1: R.Tensor((3,), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+            w3: R.Tensor((3,), dtype="float32"),
+            w4: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tuple(
+            R.Tensor((3,), dtype="float32"),
+            R.Tensor((3,), dtype="float32"),
+            R.Tensor((), dtype="int64"),
+            R.Tensor((2, 3, 4, 4), dtype="float32"),
+        ):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
+                lv1: R.Tuple(
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = R.nn.batch_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    w3,
+                    w4,
+                    axis=1,
+                    epsilon=0.1,
+                    center=True,
+                    scale=True,
+                    momentum=1.0,
+                    training=True,
+                )
+                lv2: R.Tensor((2, 3, 4, 4), dtype="float32") = lv1[0]
+                lv3: R.Tensor((3,), dtype="float32") = lv1[1]
+                lv4: R.Tensor((3,), dtype="float32") = R.zeros(R.shape([3]), dtype="float32")
+                lv5: R.Tuple(
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = (lv2, lv3, lv4, lv4, lv4)
+                lv6: R.Tensor((2, 3, 4, 4), dtype="float32") = lv5[0]
+                lv7: R.Tensor((3,), dtype="float32") = lv5[3]
+                lv8: R.Tensor((3,), dtype="float32") = lv5[4]
+                gv: R.Tuple(
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((), dtype="int64"),
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                ) = (lv7, lv8, lv, lv6)
+                R.output(gv)
+            return gv
+
+    example_args_train = (torch.randn(2, 3, 4, 4, dtype=torch.float32),)
+
+    model_3 = BatchNorm2dTraining()
+    model_3.train()  # Set to training mode
+    binding_3 = {
+        "w1": model_3.bn.weight.detach().numpy(),
+        "w2": model_3.bn.bias.detach().numpy(),
+        "w3": model_3.bn.running_mean.detach().numpy(),
+        "w4": model_3.bn.running_var.detach().numpy(),
+    }
+    verify_model(model_3, example_args_train, binding_3, expected3)
+
 
 def test_adaptive_avgpool1d():
     class AdaptiveAvgPool1d0(torch.nn.Module):
@@ -1947,6 +2038,63 @@ def main(
     verify_model(Addmm2(), example_args, {}, expected2)
 
 
+def test_sparse_addmm():
+    class SparseAddmm1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x1, x2, x3):
+            return torch.sparse.addmm(x1, x2, x3)
+
+    class SparseAddmm2(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x1, x2, x3):
+            return torch.sparse.addmm(x1, x2, x3, beta=0.8, alpha=0.5)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 10), dtype="float32"),
+            x2: R.Tensor((10, 10), dtype="float32"),
+            x3: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32")
+                lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 10), dtype="float32"),
+            x2: R.Tensor((10, 10), dtype="float32"),
+            x3: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32")
+                lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv, R.const(0.5, "float32"))
+                lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1, R.const(0.8, "float32"))
+                lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(10, 10, dtype=torch.float32),
+    )
+
+    verify_model(SparseAddmm1(), example_args, {}, expected1)
+    verify_model(SparseAddmm2(), example_args, {}, expected2)
+
+
 def test_avg_pool1d():
     class AvgPool1d1(Module):
         def __init__(self):
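The scaled variant above can be sanity-checked against the documented formula directly in eager PyTorch (standalone sketch, not part of the test suite; shapes and values are arbitrary):

    import torch

    x1 = torch.randn(10, 10)
    x3 = torch.randn(10, 10)
    # A sparse COO mat1 with a few nonzeros.
    indices = torch.tensor([[0, 1], [1, 0]])
    x2 = torch.sparse_coo_tensor(indices, torch.tensor([2.0, 3.0]), size=(10, 10))

    out = torch.sparse.addmm(x1, x2, x3, beta=0.8, alpha=0.5)
    ref = 0.8 * x1 + 0.5 * (x2.to_dense() @ x3)
    assert torch.allclose(out, ref)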
@@ -4703,6 +4851,43 @@ def main(
     verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic)
 
 
+def test_interpolate_antialiased():
+    """Test bilinear interpolation with antialiasing enabled."""
+
+    class InterpolateBilinearAA(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input, size=(64, 64), mode="bilinear", align_corners=False, antialias=True
+            )
+
+    @tvm.script.ir_module
+    class expected_bilinear_aa:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 32, 32), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 64, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 64, 64), dtype="float32") = R.image.resize2d(
+                    input,
+                    R.shape([64, 64]),
+                    roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)],
+                    layout="NCHW",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="round",
+                    cubic_alpha=-0.75,
+                    cubic_exclude=0,
+                    extrapolation_value=0.0,
+                    out_dtype="void",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 64, 64), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),)
+    verify_model(InterpolateBilinearAA(), example_args, {}, expected_bilinear_aa)
+
+
 def test_mean():
     class Mean(Module):
         def forward(self, input):
@@ -6245,6 +6430,7 @@ def main(
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
 
     model = Conv2D1()
+
     exported_program = torch.export.export(model, example_args)
     mod = from_exported_program(exported_program, keep_params_as_input=True)
     mod, params = detach_params(mod)
@@ -6281,9 +6467,7 @@ def main(
         return gv
 
     example_args = (torch.randn(256, 256, dtype=torch.float32),)
-    exported_program = export(Identity(), args=example_args)
-    mod = from_exported_program(exported_program, unwrap_unit_return_tuple=True)
-    tvm.ir.assert_structural_equal(mod, Expected)
+    verify_model(Identity(), example_args, {}, Expected, unwrap_unit_return_tuple=True)
 
 
 def test_no_bind_return_tuple():
@@ -6311,9 +6495,7 @@ def main(
         torch.randn(256, 256, dtype=torch.float32),
         torch.randn(256, 256, dtype=torch.float32),
     )
-    exported_program = export(Identity(), args=example_args)
-    mod = from_exported_program(exported_program, no_bind_return_tuple=True)
-    tvm.ir.assert_structural_equal(mod, Expected)
+    verify_model(Identity(), example_args, {}, Expected, no_bind_return_tuple=True)
 
 
 def test_empty_like():
@@ -7206,6 +7388,7 @@ def main(
             lhs: R.Tensor((B, 4), dtype="float32"),
             rhs: R.Tensor((B, 4), dtype="float32"),
         ) -> R.Tuple(R.Tensor((B, 4), dtype="float32")):
+            R.func_attr({"tir_var_lower_bound": {"s0": 0}})
             with R.dataflow():
                 lv: R.Tensor((B, 4), dtype="float32") = R.add(lhs, rhs)
                 gv: R.Tuple(R.Tensor((B, 4), dtype="float32")) = (lv,)
@@ -7615,6 +7798,39 @@ def main(
     verify_model(MatrixMultiply(), example_args, {}, Expected)
 
 
+def test_sparse_mm():
+    class SparseMatrixMultiply(Module):
+        def forward(self, sparse_input, dense_input):
+            return torch.sparse.mm(sparse_input, dense_input)
+
+    indices = torch.tensor([[0, 1, 2], [2, 0, 1]])
+    values = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
+    sparse_input = torch.sparse_coo_tensor(indices, values, size=(3, 100))
+    dense_input = torch.randn(100, 50, dtype=torch.float32)
+
+    example_args = (sparse_input, dense_input)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            sparse_input: R.Tensor((3, 100), dtype="float32"),
+            dense_input: R.Tensor((100, 50), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((3, 50), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 50), dtype="float32") = R.full(
+                    R.shape([3, 50]), R.const(0.0, "float32"), dtype="float32"
+                )
+                lv1: R.Tensor((3, 50), dtype="float32") = R.matmul(
+                    sparse_input, dense_input, out_dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((3, 50), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    verify_model(SparseMatrixMultiply(), example_args, {}, Expected)
+
+
 def test_lstm():
     class BasicLSTM(nn.Module):
         def __init__(self):
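Outside of `verify_model`, the imported module for the test above can be inspected directly; this sketch simply replays what the helper does, reusing `SparseMatrixMultiply` from `test_sparse_mm`:

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    indices = torch.tensor([[0, 1, 2], [2, 0, 1]])
    values = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
    sparse_input = torch.sparse_coo_tensor(indices, values, size=(3, 100))
    dense_input = torch.randn(100, 50, dtype=torch.float32)

    ep = export(SparseMatrixMultiply(), args=(sparse_input, dense_input))
    mod = from_exported_program(ep)
    mod.show()  # prints the R.matmul over the densified input, as in Expected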
torch.export.Dim("batch", min=1, max=64) example_args = (torch.randn(8, 4), torch.randn(16, 4)) dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}} - exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes) - mod = from_exported_program(exported_program) - tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True) + verify_model( + ConcatModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes, map_free_vars=True + ) + + +def test_dynamic_shape_with_unbounded_constraints(): + class DynamicModel(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.add.Tensor(x, x) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(("s0", 4), dtype="float32") + ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")): + s0 = T.int64(is_size_var=True) + R.func_attr({"tir_var_lower_bound": {"s0": 2}}) + with R.dataflow(): + lv: R.Tensor((s0, 4), dtype="float32") = R.add(x, x) + gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(8, 4),) + batch = torch.export.Dim("batch", min=2) + dynamic_shapes = {"x": {0: batch}} + + verify_model( + DynamicModel(), + example_args, + {}, + Expected, + dynamic_shapes=dynamic_shapes, + map_free_vars=True, + ) def test_sym_size_int(): @@ -7955,6 +8209,7 @@ def main( x: R.Tensor(("s0", 3, 4), dtype="float32") ) -> R.Tuple(R.Tensor(("s0", 12), dtype="float32")): s0 = T.int64(is_size_var=True) + R.func_attr({"tir_var_lower_bound": {"s0": 0}}) with R.dataflow(): lv: R.Tensor((s0, 12), dtype="float32") = R.reshape(x, R.shape([s0, 12])) gv: R.Tuple(R.Tensor((s0, 12), dtype="float32")) = (lv,)