@@ -4023,5 +4023,66 @@ def main(
40234023 verify_model (Scatter (), input_info , {}, expected )
40244024
40254025
4026+ def test_masked_scatter ():
4027+ class MaskedScatter1 (Module ):
4028+ def forward (self , data , mask , src ):
4029+ return data .masked_scatter (mask , src )
4030+
4031+ class MaskedScatter2 (Module ):
4032+ def forward (self , data , mask , src ):
4033+ return data .masked_scatter (mask , src )
4034+
4035+ @tvm .script .ir_module
4036+ class expected1 :
4037+ @R .function
4038+ def main (
4039+ inp_0 : R .Tensor ((5 ,), dtype = "float32" ),
4040+ inp_1 : R .Tensor ((5 ,), dtype = "bool" ),
4041+ inp_2 : R .Tensor ((10 ,), dtype = "float32" ),
4042+ ) -> R .Tensor ((5 ,), dtype = "float32" ):
4043+ with R .dataflow ():
4044+ lv : R .Tensor ((5 ,), dtype = "int32" ) = R .cumsum (
4045+ inp_1 , axis = 0 , dtype = "int32" , exclusive = False
4046+ )
4047+ lv1 : R .Tensor ((5 ,), dtype = "int32" ) = R .subtract (lv , R .const (1 , "int32" ))
4048+ lv2 : R .Tensor ((5 ,), dtype = "float32" ) = R .take (inp_2 , lv1 , axis = 0 )
4049+ lv3 : R .Tensor ((5 ,), dtype = "float32" ) = R .where (inp_1 , lv2 , inp_0 )
4050+ gv : R .Tensor ((5 ,), dtype = "float32" ) = lv3
4051+ R .output (gv )
4052+ return gv
4053+
4054+ @tvm .script .ir_module
4055+ class expected2 :
4056+ @R .function
4057+ def main (
4058+ inp_0 : R .Tensor ((2 , 5 ), dtype = "float32" ),
4059+ inp_1 : R .Tensor ((2 , 5 ), dtype = "bool" ),
4060+ inp_2 : R .Tensor ((3 , 5 ), dtype = "float32" ),
4061+ ) -> R .Tensor ((2 , 5 ), dtype = "float32" ):
4062+ with R .dataflow ():
4063+ lv : R .Tensor ((10 ,), dtype = "bool" ) = R .reshape (inp_1 , R .shape ([10 ]))
4064+ lv1 : R .Tensor ((10 ,), dtype = "int32" ) = R .cumsum (
4065+ lv , axis = 0 , dtype = "int32" , exclusive = False
4066+ )
4067+ lv2 : R .Tensor ((10 ,), dtype = "int32" ) = R .subtract (lv1 , R .const (1 , "int32" ))
4068+ lv3 : R .Tensor ((15 ,), dtype = "float32" ) = R .reshape (inp_2 , R .shape ([15 ]))
4069+ lv4 : R .Tensor ((10 ,), dtype = "float32" ) = R .take (lv3 , lv2 , axis = 0 )
4070+ lv5 : R .Tensor ((2 , 5 ), dtype = "float32" ) = R .reshape (lv4 , R .shape ([2 , 5 ]))
4071+ lv6 : R .Tensor ((2 , 5 ), dtype = "float32" ) = R .where (inp_1 , lv5 , inp_0 )
4072+ gv : R .Tensor ((2 , 5 ), dtype = "float32" ) = lv6
4073+ R .output (gv )
4074+ return gv
4075+
4076+ verify_model (
4077+ MaskedScatter1 (), [([5 ], "float32" ), ([5 ], "bool" ), ([10 ], "float32" )], {}, expected1
4078+ )
4079+ verify_model (
4080+ MaskedScatter2 (),
4081+ [([2 , 5 ], "float32" ), ([2 , 5 ], "bool" ), ([3 , 5 ], "float32" )],
4082+ {},
4083+ expected2 ,
4084+ )
4085+
4086+
40264087if __name__ == "__main__" :
40274088 tvm .testing .main ()
0 commit comments