
Commit 602ae56

Archermmt and ShiboXing authored and committed
[Relax] support masked_scatter (apache#17525)
* support masked_scatter
* remove logging
1 parent 011e09c commit 602ae56

8 files changed

Lines changed: 251 additions & 4 deletions


python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 26 additions & 0 deletions
@@ -472,6 +472,31 @@ def _masked_fill(self, node: fx.Node) -> relax.Var:
         values = self.block_builder.emit(relax.op.full_like(x, rx_value))
         return self.block_builder.emit(relax.op.where(mask, values, x))
 
+    def _masked_scatter(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        mask = self.env[node.args[1]]
+        source = self.env[node.args[2]]
+        ndim = len(mask.struct_info.shape)
+        if ndim == 1:
+            index = self.block_builder.emit(relax.op.cumsum(mask, 0, dtype="int32"))
+            index = self.block_builder.emit(relax.op.subtract(index, relax.const(1, "int32")))
+            gathered_source = self.block_builder.emit(relax.op.take(source, index, axis=0))
+        else:
+            f_mask = self.block_builder.emit(relax.op.reshape(mask, [-1]))
+            index = self.block_builder.emit(relax.op.cumsum(f_mask, 0, dtype="int32"))
+            index = self.block_builder.emit(relax.op.subtract(index, relax.const(1, "int32")))
+            source_shape = [-1] + [
+                s for idx, s in enumerate(source.struct_info.shape) if idx >= ndim
+            ]
+            f_source = self.block_builder.emit(relax.op.reshape(source, source_shape))
+            gathered_source = self.block_builder.emit(relax.op.take(f_source, index, axis=0))
+            gathered_source = self.block_builder.emit(
+                relax.op.reshape(gathered_source, x.struct_info.shape)
+            )
+        if ndim != len(x.struct_info.shape):
+            mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape))
+        return self.block_builder.emit(relax.op.where(mask, gathered_source, x))
+
     def _ones(self, node: fx.Node) -> relax.Var:
         import torch

@@ -695,6 +720,7 @@ def create_convert_map(
             "index_select": self._index_select,
             "masked_fill_": self._inplace_masked_fill,
             "masked_fill": self._masked_fill,
+            "masked_scatter": self._masked_scatter,
             "new_ones": self._new_ones,
             "ones": self._ones,
             "tensor": self._tensor,
src/contrib/msc/core/ir/graph_builder.cc

Lines changed: 3 additions & 1 deletion
@@ -704,7 +704,9 @@ const MSCPrim RelaxGraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const S
 }
 
 void RelaxGraphBuilder::VisitExpr_(const relax::ConstantNode* op) {
-  AddNode(GetRef<relax::Constant>(op));
+  if (!expr_tensor_map_.count(GetRef<relax::Constant>(op))) {
+    AddNode(GetRef<relax::Constant>(op));
+  }
 }
 
 void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,

src/contrib/msc/core/transform/set_expr_layout.cc

Lines changed: 10 additions & 3 deletions
@@ -492,9 +492,16 @@ InferLayoutOutput ForwardInferLayoutTake(const Call& call,
     return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs());
   }
   if (indices_layout->layout.defined()) {
-    size_t indices_size = indices_layout->layout.ndim();
-    LayoutDecision output_layout =
-        LayoutUtils::ExpandLayout(indices_layout, std::vector<size_t>{indices_size});
+    std::vector<size_t> expand_axes;
+    for (size_t i = indices_layout->layout.ndim(); i < output_shape.size(); i++) {
+      expand_axes.push_back(i);
+    }
+    LayoutDecision output_layout;
+    if (expand_axes.size() == 0) {
+      output_layout = indices_layout;
+    } else {
+      output_layout = LayoutUtils::ExpandLayout(indices_layout, expand_axes);
+    }
     return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs());
   }
   return InferLayoutOutput();
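For context on this layout fix: take along axis 0 produces an output whose leading dimensions come from the indices and whose trailing dimensions come from the data, so the output layout should extend the indices layout by one axis per extra output dimension rather than by exactly one axis as before (and by none when the ranks already match). A small NumPy illustration of the shape rule (not from the commit, purely explanatory):

import numpy as np

data = np.zeros((4, 5, 6))
indices = np.array([0, 2, 3])           # 1-D indices, e.g. layout "A"
out = np.take(data, indices, axis=0)    # shape (3, 5, 6): indices dims plus data dims after axis 0

# The updated rule expands the indices layout with axes {1, 2} here, i.e. the range
# [indices.ndim, out.ndim); with 1-D data it would expand with nothing.
assert out.shape == (3, 5, 6)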

src/contrib/msc/framework/torch/torch_opcode.cc

Lines changed: 20 additions & 0 deletions
@@ -224,6 +224,12 @@ class TorchConstantCodeGen : public TorchOpCode {
       } else if (dtype == "float32") {
         stack_.assign(module_ref(), node()->GetTypeAttr<float>("scalar"));
       }
+    } else if (dtype == "bool") {
+      stack_.func_call("register_buffer", "", "self")
+          .call_arg(DocUtils::ToStr(ref_name))
+          .inplace_start("torch.BoolTensor")
+          .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
+          .inplace_end();
     } else if (dtype == "int32") {
       stack_.func_call("register_buffer", "", "self")
           .call_arg(DocUtils::ToStr(ref_name))

@@ -658,6 +664,18 @@ class TorchStridedSliceCodeGen : public TorchOpCode {
   }
 };
 
+class TorchTakeCodeGen : public TorchOpCode {
+  TORCH_OP_CODEGEN_METHODS(TorchTakeCodeGen)
+
+ protected:
+  void CodeGenForward() final {
+    if (node()->InputAt(1)->DTypeName() == "int32") {
+      stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
+    }
+    stack_.assign(IdxNode(), DocUtils::ToIndex(IdxInput(0), IdxInput(1)));
+  }
+};
+
 class TorchTriCodeGen : public TorchOpCode {
   TORCH_OP_CODEGEN_METHODS(TorchTriCodeGen)

@@ -738,6 +756,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
   map->emplace("subtract", std::make_shared<TorchSimpleCodeGen>("", "torch.subtract"));
   map->emplace("tan", std::make_shared<TorchSimpleCodeGen>("", "torch.tan"));
   map->emplace("tanh", std::make_shared<TorchSimpleCodeGen>("", "torch.tanh"));
+  map->emplace("where", std::make_shared<TorchSimpleCodeGen>("", "torch.where"));
 
   // reduce ops
   map->emplace("max", std::make_shared<TorchReduceAxesCodeGen>("", "torch.max"));

@@ -771,6 +790,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
   map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
   map->emplace("split", std::make_shared<TorchSplitCodeGen>("", "torch.split"));
   map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("", ""));
+  map->emplace("take", std::make_shared<TorchTakeCodeGen>("", ""));
 
   // create ops
   map->emplace("constant", std::make_shared<TorchConstantCodeGen>("nn.Parameter", ""));

tests/python/contrib/test_msc/test_graph_build.py

Lines changed: 85 additions & 0 deletions
@@ -2472,6 +2472,91 @@ def forward(self, data, index, src):
     )
 
 
+@pytest.mark.parametrize("dynamic", [True, False])
+def test_masked_scatter(dynamic):
+    """test graph builder for masked_scatter"""
+
+    dim = "dim" if dynamic else 5
+
+    class MaskedScatter1(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    class MaskedScatter2(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    expected1 = {
+        "inputs": [
+            {"name": "inp_0", "shape": [dim], "dtype": "float32", "layout": "A"},
+            {"name": "inp_1", "shape": [dim], "dtype": "bool", "layout": "A"},
+            {"name": "inp_2", "shape": [10], "dtype": "float32", "layout": "A"},
+        ],
+        "outputs": [{"name": "where", "shape": [dim], "dtype": "float32", "layout": "A"}],
+        "nodes": {
+            "total": 8,
+            "input": 3,
+            "cumsum": 1,
+            "constant": 1,
+            "subtract": 1,
+            "take": 1,
+            "where": 1,
+        },
+    }
+    expected2 = {
+        "inputs": [
+            {
+                "name": "inp_0",
+                "shape": [2, dim],
+                "dtype": "float32",
+                "layout": "" if dynamic else "BA",
+            },
+            {
+                "name": "inp_1",
+                "shape": [2, dim],
+                "dtype": "bool",
+                "layout": "" if dynamic else "BA",
+            },
+            {
+                "name": "inp_2",
+                "shape": [3, dim],
+                "dtype": "float32",
+                "layout": "" if dynamic else "BA",
+            },
+        ],
+        "outputs": [
+            {
+                "name": "where",
+                "shape": [2, dim],
+                "dtype": "float32",
+                "layout": "" if dynamic else "BA",
+            }
+        ],
+        "nodes": {
+            "total": 11,
+            "input": 3,
+            "reshape": 3,
+            "cumsum": 1,
+            "constant": 1,
+            "subtract": 1,
+            "take": 1,
+            "where": 1,
+        },
+    }
+    if dynamic:
+        expected1["prims"] = {"total": 1, "shape": 1}
+        expected2["prims"] = {"total": 5, "shape": 1, "Int": 2, "Mul": 2}
+
+    verify_model(
+        MaskedScatter1(), [([dim], "float32"), ([dim], "bool"), ([10], "float32")], expected1
+    )
+    verify_model(
+        MaskedScatter2(),
+        [([2, dim], "float32"), ([2, dim], "bool"), ([3, dim], "float32")],
+        expected2,
+    )
+
+
 def test_put():
     """test graph builder for index_put"""

tests/python/contrib/test_msc/test_translate_relax.py

Lines changed: 23 additions & 0 deletions
@@ -1193,6 +1193,29 @@ def forward(self, data, index, src):
     verify_model(Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")])
 
 
+def test_masked_scatter():
+    """test relax translator for masked_scatter"""
+
+    class MaskedScatter1(Module):
+        def __init__(self):
+            super().__init__()
+            self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH)
+
+        def forward(self, data, src):
+            return data.masked_scatter(self.mask, src)
+
+    class MaskedScatter2(Module):
+        def __init__(self):
+            super().__init__()
+            self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH)
+
+        def forward(self, data, src):
+            return data.masked_scatter(self.mask, src)
+
+    verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")])
+    verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")])
+
+
 def test_put():
     """test relax translator for index_put"""

tests/python/contrib/test_msc/test_translate_torch.py

Lines changed: 23 additions & 0 deletions
@@ -1173,6 +1173,29 @@ def forward(self, data, index, src):
     )
 
 
+def test_masked_scatter():
+    """test torch translator for masked_scatter"""
+
+    class MaskedScatter1(Module):
+        def __init__(self):
+            super().__init__()
+            self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH)
+
+        def forward(self, data, src):
+            return data.masked_scatter(self.mask, src)
+
+    class MaskedScatter2(Module):
+        def __init__(self):
+            super().__init__()
+            self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH)
+
+        def forward(self, data, src):
+            return data.masked_scatter(self.mask, src)
+
+    verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")], True)
+    verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")], True)
+
+
 def test_put():
     """test torch translator for index_put"""

tests/python/relax/test_frontend_from_fx.py

Lines changed: 61 additions & 0 deletions
@@ -4023,5 +4023,66 @@ def main(
     verify_model(Scatter(), input_info, {}, expected)
 
 
+def test_masked_scatter():
+    class MaskedScatter1(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    class MaskedScatter2(Module):
+        def forward(self, data, mask, src):
+            return data.masked_scatter(mask, src)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"),
+            inp_1: R.Tensor((5,), dtype="bool"),
+            inp_2: R.Tensor((10,), dtype="float32"),
+        ) -> R.Tensor((5,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5,), dtype="int32") = R.cumsum(
+                    inp_1, axis=0, dtype="int32", exclusive=False
+                )
+                lv1: R.Tensor((5,), dtype="int32") = R.subtract(lv, R.const(1, "int32"))
+                lv2: R.Tensor((5,), dtype="float32") = R.take(inp_2, lv1, axis=0)
+                lv3: R.Tensor((5,), dtype="float32") = R.where(inp_1, lv2, inp_0)
+                gv: R.Tensor((5,), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 5), dtype="float32"),
+            inp_1: R.Tensor((2, 5), dtype="bool"),
+            inp_2: R.Tensor((3, 5), dtype="float32"),
+        ) -> R.Tensor((2, 5), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10,), dtype="bool") = R.reshape(inp_1, R.shape([10]))
+                lv1: R.Tensor((10,), dtype="int32") = R.cumsum(
+                    lv, axis=0, dtype="int32", exclusive=False
+                )
+                lv2: R.Tensor((10,), dtype="int32") = R.subtract(lv1, R.const(1, "int32"))
+                lv3: R.Tensor((15,), dtype="float32") = R.reshape(inp_2, R.shape([15]))
+                lv4: R.Tensor((10,), dtype="float32") = R.take(lv3, lv2, axis=0)
+                lv5: R.Tensor((2, 5), dtype="float32") = R.reshape(lv4, R.shape([2, 5]))
+                lv6: R.Tensor((2, 5), dtype="float32") = R.where(inp_1, lv5, inp_0)
+                gv: R.Tensor((2, 5), dtype="float32") = lv6
+                R.output(gv)
+            return gv
+
+    verify_model(
+        MaskedScatter1(), [([5], "float32"), ([5], "bool"), ([10], "float32")], {}, expected1
+    )
+    verify_model(
+        MaskedScatter2(),
+        [([2, 5], "float32"), ([2, 5], "bool"), ([3, 5], "float32")],
+        {},
+        expected2,
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
