Skip to content

Commit 5534959

Browse files
rebel-shshinjunrushao
authored andcommitted
[FRONTEND][TFLITE][BugFix] Fix variable typo in batchmatmul converting func (apache#15259)
* TFLite frontend bug fix * Update tflite.py * lint * Add pytest
1 parent 36f72d8 commit 5534959

2 files changed

Lines changed: 26 additions & 4 deletions

File tree

python/tvm/relay/frontend/tflite.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3006,7 +3006,9 @@ def convert_batch_matmul(self, op):
30063006
rank_diff = rank_a - rank_b
30073007
new_b_shape = _op.concatenate(
30083008
[
3009-
_expr.const([1] * rank_diff, dtype=_infer_type(b_shape).checked_type.dtype),
3009+
_expr.const(
3010+
[1] * rank_diff, dtype=_infer_type(new_b_shape).checked_type.dtype
3011+
),
30103012
shape_b,
30113013
],
30123014
0,
@@ -3015,7 +3017,9 @@ def convert_batch_matmul(self, op):
30153017
rank_diff = rank_b - rank_a
30163018
new_a_shape = _op.concatenate(
30173019
[
3018-
_expr.const([1] * rank_diff, dtype=_infer_type(a_shape).checked_type.dtype),
3020+
_expr.const(
3021+
[1] * rank_diff, dtype=_infer_type(new_a_shape).checked_type.dtype
3022+
),
30193023
shape_a,
30203024
],
30213025
0,
@@ -3041,9 +3045,9 @@ def convert_batch_matmul(self, op):
30413045
_op.concatenate([out_batch, _op.strided_slice(shape_b, [rank_b - 2], [rank_b])], 0)
30423046
)
30433047
if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape):
3044-
input_a = _op.transform.broadcast_to(a, a_broadcasted_shape)
3048+
input_a = _op.transform.broadcast_to(input_a, a_broadcasted_shape)
30453049
if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape):
3046-
input_b = _op.transform.broadcast_to(b, b_broadcasted_shape)
3050+
input_b = _op.transform.broadcast_to(input_b, b_broadcasted_shape)
30473051

30483052
input_a = self.flatten_to_nd(input_a, shape_a, 3)
30493053
input_b = self.flatten_to_nd(input_b, shape_b, 3)

tests/python/frontend/tflite/test_forward.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,15 @@ def test_forward_batch_matmul(config):
794794
adjoint_b=False,
795795
quantized=config[2],
796796
)
797+
_test_batch_matmul(
798+
(2, 3, 5, 4),
799+
(1, 3, 5, 4),
800+
dtype=config[0],
801+
out_dtype=config[1],
802+
adjoint_a=True,
803+
adjoint_b=False,
804+
quantized=config[2],
805+
)
797806
_test_batch_matmul(
798807
(3, 5, 4),
799808
(3, 5, 4),
@@ -803,6 +812,15 @@ def test_forward_batch_matmul(config):
803812
adjoint_b=True,
804813
quantized=config[2],
805814
)
815+
_test_batch_matmul(
816+
(2, 3, 5, 4),
817+
(1, 3, 5, 4),
818+
dtype=config[0],
819+
out_dtype=config[1],
820+
adjoint_a=False,
821+
adjoint_b=True,
822+
quantized=config[2],
823+
)
806824
_test_batch_matmul(
807825
(3, 4, 5, 6), (3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]
808826
)

0 commit comments

Comments
 (0)