Skip to content

Commit c286638

Browse files
author
Siyuan Feng
authored
[relax] Fix tree attention for Qwen2-1.5 models (#17700)
Fix the compilation error for Qwen2-1.5 models in the tree attention implementation for vulkan backend.
1 parent dcc8891 commit c286638

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

python/tvm/relax/frontend/nn/llm/tree_attn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,8 @@ def batch_tree_attn( # pylint: disable=too-many-branches
425425
batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x)
426426

427427
if T.tvm_thread_invariant(batch_idx[0] < batch_size_plus_1 - 1):
428-
b_idx: T.int32 = batch_idx[0]
429-
LH_start: T.int32 = tile_id[0] * tile_x
428+
b_idx: T.int32(is_size_var=True) = batch_idx[0]
429+
LH_start: T.int32(is_size_var=True) = tile_id[0] * tile_x
430430
q_indptr_val: T.int32 = q_indptr[b_idx]
431431

432432
kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
@@ -1049,8 +1049,8 @@ def tree_attn_paged_kv(
10491049
batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x)
10501050

10511051
if T.tvm_thread_invariant(batch_idx[0] < batch_size):
1052-
b_idx: T.int32 = batch_idx[0]
1053-
LH_start: T.int32 = tile_id[0] * tile_x
1052+
b_idx: T.int32(is_size_var=True) = batch_idx[0]
1053+
LH_start: T.int32(is_size_var=True) = tile_id[0] * tile_x
10541054
q_indptr_val: T.int32 = q_indptr[b_idx]
10551055

10561056
cur_page_indptr_begin: T.int32 = page_indptr[b_idx]

0 commit comments

Comments
 (0)