From 1f99857cb6eb63c30a4c54c92d4349b25cce622f Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Tue, 2 Dec 2025 00:10:04 +0800 Subject: [PATCH] Add support for masked_select --- .../torch/base_fx_graph_translator.py | 21 +++++++++++ .../torch/exported_program_translator.py | 11 ++++++ python/tvm/script/ir_builder/relax/ir.py | 2 + .../test_frontend_from_exported_program.py | 37 +++++++++++++++++++ 4 files changed, 71 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index e9a9cdd9394f..58c729b62f3b 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -23,6 +23,7 @@ import math from typing import Callable, Dict, Optional, Tuple, Union, List +import tvm from tvm import relax, tir @@ -2384,6 +2385,26 @@ def _masked_fill(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.where(mask, values, x)) + def _masked_select(self, node: fx.Node) -> relax.Var: + data = self.env[node.args[0]] + mask = self.env[node.args[1]] + + data_shape = self.shape_of(data) + mask_shape = self.shape_of(mask) + shapes_equal = tvm.ir.structural_equal(data_shape, mask_shape) + + if not shapes_equal: + mask = self.block_builder.emit(relax.op.broadcast_to(mask, data_shape)) + + data_flat = self.block_builder.emit(relax.op.reshape(data, [-1])) + mask_flat = self.block_builder.emit(relax.op.reshape(mask, [-1])) + indices = self.block_builder.emit(relax.op.nonzero(mask_flat)) + indices_1d = self.block_builder.emit(relax.op.squeeze(indices, axis=[0])) + + result = self.block_builder.emit(relax.op.take(data_flat, indices_1d, axis=0)) + + return result + def _new_ones(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) self_var = args[0] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 2ec61796c31a..b0786999b215 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1153,6 +1153,11 @@ def _as_strided(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.reshape(x, size)) + ########## Symbolic Shape Constraints ########## + + def _symbolic_comparison(self, _: fx.Node) -> relax.Expr: + return self.block_builder.emit(relax.const(True, dtype="bool")) + ########## Others ########## def create_convert_map( @@ -1457,6 +1462,7 @@ def create_convert_map( "linspace.default": self._linspace, "masked_fill.Scalar": self._masked_fill, "masked_fill_.Scalar": self._inplace_masked_fill, + "masked_select.default": self._masked_select, "new_ones.default": self._new_ones, "new_zeros.default": self._new_zeros, "one_hot.default": self._one_hot, @@ -1477,6 +1483,11 @@ def create_convert_map( "item.default": self._item, "sym_size.int": self._sym_size_int, "_local_scalar_dense.default": self._item, + # symbolic shape constraints (no-ops for compilation) + "sym_constrain_range_for_size.default": lambda node: self.env[node.args[0]], + "_assert_scalar.default": lambda node: self.env[node.args[0]], + "ge": self._symbolic_comparison, + "le": self._symbolic_comparison, } def _process_derived_symbol( diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index f221a1308965..141361a729c4 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -137,6 +137,7 @@ multiply, negative, nn, + nonzero, not_equal, null_value, ones, @@ -882,6 +883,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "multinomial_from_uniform", "multiply", "negative", + "nonzero", "not_equal", "null_value", "ones", diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 0658dbfaf31e..4480e25f7d6c 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -6231,6 +6231,43 @@ def main( verify_model(Masked_Fill_Inplace(), example_args, {}, Expected) +def test_masked_select(): + class MaskedSelect(Module): + def forward(self, data: torch.Tensor, mask: torch.Tensor): + return torch.masked_select(data, mask) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((2, 3), dtype="float32"), mask: R.Tensor((2, 3), dtype="bool") + ) -> R.Tuple(R.Tensor(dtype="float32", ndim=1)): + R.func_attr( + { + "tir_var_lower_bound": {"u0": 0, "u1": 0}, + "tir_var_upper_bound": {"u0": 6, "u1": 6}, + } + ) + with R.dataflow(): + lv: R.Tensor((6,), dtype="float32") = R.reshape(data, R.shape([6])) + lv1: R.Tensor((6,), dtype="bool") = R.reshape(mask, R.shape([6])) + lv2: R.Tensor(dtype="int64", ndim=2) = R.nonzero(lv1) + lv3: R.Tensor(dtype="int64", ndim=1) = R.squeeze(lv2, axis=[0]) + lv4: R.Tensor(dtype="float32", ndim=1) = R.take(lv, lv3, axis=0, mode="fast") + lv5: R.Tensor((), dtype="int64") = R.const(0, "int64") + lv6: R.Tensor((), dtype="bool") = R.const(True, "bool") + lv7: R.Tensor((), dtype="bool") = R.const(True, "bool") + gv: R.Tuple(R.Tensor(dtype="float32", ndim=1)) = (lv4,) + R.output(gv) + return gv + + example_args = ( + torch.randn(2, 3, dtype=torch.float32), + torch.tensor([[True, False, True], [False, True, False]]), + ) + verify_model(MaskedSelect(), example_args, {}, Expected) + + def test_new_ones(): class NewOnes(Module): def forward(self, x):