 # pylint: disable=too-many-statements,too-many-lines,too-many-arguments,invalid-name
 import enum
 import math
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 import tvm
 from tvm import relax as rx
@@ -86,6 +86,7 @@ class AttnKind(enum.IntEnum):

     MHA = 0
     MLA = 1
+    MHA_SLIDING = 3


 class RopeMode(enum.IntEnum):
@@ -301,7 +302,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me

     def __init__(  # pylint: disable=too-many-locals
         self,
-        attn_kind: Literal["mha", "mla"],
+        attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla", "mha_sliding"]]],
         max_batch_size: tir.Var,
         max_total_seq_len: tir.Var,
         prefill_chunk_size: tir.Var,
@@ -377,8 +378,16 @@ def __init__( # pylint: disable=too-many-locals
             dtype_q=dtype,
             dtype_kv=dtype,
             dtype_o=dtype,
-            qk_head_dim=qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim,
-            v_head_dim=v_head_dim if attn_kind == "mha" else mla_original_v_head_dim,
+            qk_head_dim=(
+                qk_head_dim
+                if (attn_kind == "mha" or isinstance(attn_kind, List))
+                else mla_original_qk_head_dim
+            ),
+            v_head_dim=(
+                v_head_dim
+                if (attn_kind == "mha" or isinstance(attn_kind, List))
+                else mla_original_v_head_dim
+            ),
             target=target,
             enable_inline_rope=rope_mode == RopeMode.INLINE,
         )
@@ -391,7 +400,7 @@ def __init__( # pylint: disable=too-many-locals
                 v_head_dim=v_head_dim,
                 target=target,
             )
-            if attn_kind == "mha"
+            if (attn_kind == "mha" or isinstance(attn_kind, List))
             else []
         )
         flashinfer_mla_mods = (
@@ -420,7 +429,7 @@ def __init__( # pylint: disable=too-many-locals
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
             ]
-            if attn_kind == "mha"
+            if (attn_kind == "mha" or isinstance(attn_kind, List))
             else [rx.Tuple([]) for _ in range(6)]
         )
         mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else [])
@@ -430,6 +439,11 @@ def __init__( # pylint: disable=too-many-locals
         if attn_kind == "mla":
             attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))

+
+        if isinstance(attn_kind, List):
+            attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
+        else:
+            attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
         args = [
             rx.ShapeExpr(
                 [
@@ -482,7 +496,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods

     def __init__(  # pylint: disable=too-many-locals
         self,
-        attn_kind: Literal["mha", "mla"],
+        attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla", "mha_sliding"]]],
         max_batch_size: tir.Var,
         max_total_seq_len: tir.Var,
         prefill_chunk_size: tir.Var,
@@ -553,7 +567,12 @@ def __init__( # pylint: disable=too-many-locals
         target : Target
             The target to build the model to.
         """
-
+        if isinstance(attn_kind, List):
+            attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
+        else:
+            attn_kind = [
+                int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)
+            ]
         bb = rx.BlockBuilder.current()
         args = [
             rx.ShapeExpr(
@@ -570,9 +589,7 @@ def __init__( # pylint: disable=too-many-locals
             rx.PrimValue(num_key_value_heads),
             rx.PrimValue(qk_head_dim),
             rx.PrimValue(v_head_dim),
-            rx.ShapeExpr(
-                [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
-            ),
+            rx.ShapeExpr(attn_kind),
             rx.PrimValue(enable_disaggregation),
             rx.PrimValue(rope_mode),
             rx.PrimValue(rope_scale),
@@ -614,9 +631,9 @@ def __init__( # pylint: disable=too-many-locals
         else:
             # pylint: disable=line-too-long
             # fmt: off
-            ragged_qk_head_dim = qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim
-            ragged_v_head_dim = v_head_dim if attn_kind == "mha" else mla_original_v_head_dim
-            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
+            ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim
+            ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim
+            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
             mha_functions = (
                 [
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]),
@@ -626,7 +643,7 @@ def __init__( # pylint: disable=too-many-locals
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
                 ]
-                if attn_kind == "mha"
+                if (attn_kind == "mha" or isinstance(attn_kind, List))
                 else [rx.Tuple([]) for _ in range(6)]
             )
             mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else [])
@@ -641,7 +658,7 @@ def __init__( # pylint: disable=too-many-locals
                 [
                     rx.Tuple(attn_merge_functions),
                     bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-                    bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+                    bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
                     bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
                     bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
                 ]
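
For context, the net effect of the two identical blocks added to `FlashInferPagedKVCache.__init__` and `TIRPagedKVCache.__init__` is to turn `attn_kind` into one `AttnKind` integer per hidden layer before it is passed to `rx.ShapeExpr`. Below is a minimal standalone sketch of that behaviour; `normalize_attn_kind` is a hypothetical helper name used only for illustration (the patch inlines this logic in both constructors), and the enum values mirror the ones defined in the diff.

```python
import enum
from typing import List, Literal, Union


class AttnKind(enum.IntEnum):
    """Per-layer attention kinds, mirroring the enum extended in this patch."""

    MHA = 0
    MLA = 1
    MHA_SLIDING = 3


def normalize_attn_kind(
    attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla", "mha_sliding"]]],
    num_hidden_layers: int,
) -> List[int]:
    """Expand a single kind or a per-layer list into one enum value per layer."""
    if isinstance(attn_kind, list):
        # Per-layer specification: the list names one kind per hidden layer.
        return [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
    # Homogeneous specification: broadcast the same kind across all layers.
    return [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]


# Example: alternating global and sliding-window attention in a 4-layer model.
print(normalize_attn_kind(["mha", "mha_sliding", "mha", "mha_sliding"], 4))  # [0, 3, 0, 3]
print(normalize_attn_kind("mla", 4))  # [1, 1, 1, 1]
```

Note that only the `attn_kind == "mha"` checks were widened to also accept a list, so a per-layer list always takes the MHA kernel-dispatch path; the MLA-specific kernels are still selected only when `attn_kind` is the single string `"mla"`.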