@@ -232,7 +232,31 @@ def test_saturation():
232232 np .testing .assert_equal (op_res .numpy (), golden_output )
233233
234234
235+ def test_ignore_broadcast_axis ():
236+ data_dtype = "uint8"
237+
238+ x = relay .var ("x" , shape = (4 ,), dtype = data_dtype )
239+ y = relay .var ("y" , shape = (4 ,), dtype = data_dtype )
240+ z = relay .qnn .op .add (
241+ lhs = x ,
242+ rhs = y ,
243+ lhs_scale = relay .const (0.00784314 , "float32" ),
244+ lhs_zero_point = relay .const (127 , "int32" ),
245+ rhs_scale = relay .const (0.00784314 , "float32" ),
246+ rhs_zero_point = relay .const (127 , "int32" ),
247+ output_scale = relay .const (0.00784314 , "float32" ),
248+ output_zero_point = relay .const (127 , "int32" ),
249+ lhs_axis = 1 ,
250+ rhs_axis = 1 ,
251+ )
252+
253+ func = relay .Function ([x , y ], z )
254+ mod = tvm .IRModule .from_expr (func )
255+ mod = relay .transform .InferType ()(mod )
256+
257+
235258if __name__ == "__main__" :
236259 test_tflite_same_io_qnn_params ()
237260 test_tflite_different_io_qnn_params ()
238261 test_saturation ()
262+ test_ignore_broadcast_axis ()
0 commit comments