Skip to content

Commit 9d6741c

Browse files
committed
[QNN] Fix broadcast for invalid axis
1 parent 6ccde09 commit 9d6741c

2 files changed

Lines changed: 33 additions & 4 deletions

File tree

src/relay/qnn/op/op_common.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,14 +255,19 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
255255
}
256256

257257
const BroadcastAttrs* broadcast_attrs = attrs.as<BroadcastAttrs>();
258-
int lhs_axis = broadcast_attrs->lhs_axis;
259-
int rhs_axis = broadcast_attrs->rhs_axis;
258+
ICHECK(broadcast_attrs);
260259

261260
auto lhs_rank = static_cast<int>(lhs_data->shape.size());
262261
auto rhs_rank = static_cast<int>(rhs_data->shape.size());
263262

264-
lhs_axis = (lhs_axis < 0) ? ((lhs_rank > 0) ? lhs_rank + lhs_axis : 0) : lhs_axis;
265-
rhs_axis = (rhs_axis < 0) ? ((rhs_rank > 0) ? rhs_rank + rhs_axis : 0) : rhs_axis;
263+
auto get_broadcast_axis = [](int rank, int axis_from_attr) {
264+
if (rank <= 1) return 0;
265+
if (axis_from_attr < 0) return rank + axis_from_attr;
266+
return axis_from_attr;
267+
};
268+
269+
const int lhs_axis = get_broadcast_axis(lhs_rank, broadcast_attrs->lhs_axis);
270+
const int rhs_axis = get_broadcast_axis(rhs_rank, broadcast_attrs->rhs_axis);
266271

267272
// If zero point and scale are scalar then axis doesn't matter.
268273
bool lhs_scale_is_scalar = (types[2].as<TensorTypeNode>())->shape.size() == 0;

tests/python/relay/test_op_qnn_add.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,31 @@ def test_saturation():
232232
np.testing.assert_equal(op_res.numpy(), golden_output)
233233

234234

235+
def test_ignore_broadcast_axis():
236+
data_dtype = "uint8"
237+
238+
x = relay.var("x", shape=(4,), dtype=data_dtype)
239+
y = relay.var("y", shape=(4,), dtype=data_dtype)
240+
z = relay.qnn.op.add(
241+
lhs=x,
242+
rhs=y,
243+
lhs_scale=relay.const(0.00784314, "float32"),
244+
lhs_zero_point=relay.const(127, "int32"),
245+
rhs_scale=relay.const(0.00784314, "float32"),
246+
rhs_zero_point=relay.const(127, "int32"),
247+
output_scale=relay.const(0.00784314, "float32"),
248+
output_zero_point=relay.const(127, "int32"),
249+
lhs_axis=1,
250+
rhs_axis=1,
251+
)
252+
253+
func = relay.Function([x, y], z)
254+
mod = tvm.IRModule.from_expr(func)
255+
mod = relay.transform.InferType()(mod)
256+
257+
235258
if __name__ == "__main__":
236259
test_tflite_same_io_qnn_params()
237260
test_tflite_different_io_qnn_params()
238261
test_saturation()
262+
test_ignore_broadcast_axis()

0 commit comments

Comments
 (0)