File tree Expand file tree Collapse file tree
python/tvm/relax/frontend/nn/llm Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -425,8 +425,8 @@ def batch_tree_attn( # pylint: disable=too-many-branches
425425 batch_tiles [0 ] = T .ceildiv (batch_rows [0 ], tile_x )
426426
427427 if T .tvm_thread_invariant (batch_idx [0 ] < batch_size_plus_1 - 1 ):
428- b_idx : T .int32 = batch_idx [0 ]
429- LH_start : T .int32 = tile_id [0 ] * tile_x
428+ b_idx : T .int32 ( is_size_var = True ) = batch_idx [0 ]
429+ LH_start : T .int32 ( is_size_var = True ) = tile_id [0 ] * tile_x
430430 q_indptr_val : T .int32 = q_indptr [b_idx ]
431431
432432 kv_chunk_len [0 ] = kv_indptr [b_idx + 1 ] - kv_indptr [b_idx ]
@@ -1049,8 +1049,8 @@ def tree_attn_paged_kv(
10491049 batch_tiles [0 ] = T .ceildiv (batch_rows [0 ], tile_x )
10501050
10511051 if T .tvm_thread_invariant (batch_idx [0 ] < batch_size ):
1052- b_idx : T .int32 = batch_idx [0 ]
1053- LH_start : T .int32 = tile_id [0 ] * tile_x
1052+ b_idx : T .int32 ( is_size_var = True ) = batch_idx [0 ]
1053+ LH_start : T .int32 ( is_size_var = True ) = tile_id [0 ] * tile_x
10541054 q_indptr_val : T .int32 = q_indptr [b_idx ]
10551055
10561056 cur_page_indptr_begin : T .int32 = page_indptr [b_idx ]
You can’t perform that action at this time.
0 commit comments