1515from sglang .jit_kernel .hicache import (
1616 transfer_hicache_one_layer as jit_transfer_hicache_one_layer ,
1717)
18- from sglang .srt .mem_cache .memory_pool import KVCache , MHATokenToKVPool , MLATokenToKVPool
18+ from sglang .srt .mem_cache .memory_pool import (
19+ KVCache ,
20+ MHATokenToKVPool ,
21+ MLATokenToKVPool ,
22+ NSATokenToKVPool ,
23+ )
1924from sglang .srt .utils import is_cuda , is_npu , is_xpu
2025
2126_is_cuda = is_cuda ()
@@ -689,7 +694,9 @@ def __init__(
689694 pin_memory : bool = True ,
690695 device : str = "cpu" ,
691696 allocator_type : str = "default" ,
697+ override_kv_cache_dim : Optional [int ] = None ,
692698 ):
699+ self .override_kv_cache_dim = override_kv_cache_dim
693700 super ().__init__ (
694701 device_pool ,
695702 host_to_device_ratio ,
@@ -711,13 +718,10 @@ def get_size_per_token(self):
711718 self .kv_lora_rank = self .device_pool .kv_lora_rank
712719 self .qk_rope_head_dim = self .device_pool .qk_rope_head_dim
713720 self .layer_num = self .device_pool .layer_num
714-
715- return (
716- (self .kv_lora_rank + self .qk_rope_head_dim )
717- * 1
718- * self .dtype .itemsize
719- * self .layer_num
721+ self .kv_cache_dim = self .override_kv_cache_dim or (
722+ self .kv_lora_rank + self .qk_rope_head_dim
720723 )
724+ return self .kv_cache_dim * self .dtype .itemsize * self .layer_num
721725
722726 def get_ksize_per_token (self ):
723727 return self .get_size_per_token ()
@@ -728,22 +732,22 @@ def init_kv_buffer(self):
728732 self .layer_num ,
729733 self .size ,
730734 1 ,
731- self .kv_lora_rank + self . qk_rope_head_dim ,
735+ self .kv_cache_dim ,
732736 )
733737 elif self .layout == "page_first" :
734738 dims = (
735739 self .size ,
736740 self .layer_num ,
737741 1 ,
738- self .kv_lora_rank + self . qk_rope_head_dim ,
742+ self .kv_cache_dim ,
739743 )
740744 elif self .layout == "page_first_direct" :
741745 dims = (
742746 self .page_num ,
743747 self .layer_num ,
744748 self .page_size ,
745749 1 ,
746- self .kv_lora_rank + self . qk_rope_head_dim ,
750+ self .kv_cache_dim ,
747751 )
748752 # Ascend-specific: Aligns with NPUMLATokenToKVPool layout
749753 # Separately allocate k_buffer and v_buffer for easier data transfer.
@@ -774,9 +778,7 @@ def init_kv_buffer(self):
774778 return self .k_buffer
775779 else :
776780 raise ValueError (f"Unsupported layout: { self .layout } " )
777- self .token_stride_size = (
778- self .kv_lora_rank + self .qk_rope_head_dim
779- ) * self .dtype .itemsize
781+ self .token_stride_size = self .kv_cache_dim * self .dtype .itemsize
780782 self .layout_dim = self .token_stride_size * self .layer_num
781783
782784 alloc_func = ALLOC_MEMORY_FUNCS [self .device_pool .device ]
@@ -933,7 +935,7 @@ def get_dummy_flat_data_page(self) -> torch.Tensor:
933935 self .layer_num ,
934936 self .page_size ,
935937 1 ,
936- self .kv_lora_rank + self . qk_rope_head_dim ,
938+ self .kv_cache_dim ,
937939 ),
938940 dtype = self .dtype ,
939941 device = self .device ,
@@ -946,14 +948,14 @@ def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
946948 self .layer_num ,
947949 self .page_size ,
948950 1 ,
949- self .kv_lora_rank + self . qk_rope_head_dim ,
951+ self .kv_cache_dim ,
950952 )
951953 elif self .layout == "page_first" :
952954 self .kv_buffer [index : index + self .page_size , :, :, :] = data_page .reshape (
953955 self .page_size ,
954956 self .layer_num ,
955957 1 ,
956- self .kv_lora_rank + self . qk_rope_head_dim ,
958+ self .kv_cache_dim ,
957959 )
958960 elif self .layout == "page_first_direct" :
959961 real_index = index // self .page_size
@@ -962,7 +964,7 @@ def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
962964 self .layer_num ,
963965 self .page_size ,
964966 1 ,
965- self .kv_lora_rank + self . qk_rope_head_dim ,
967+ self .kv_cache_dim ,
966968 )
967969 else :
968970 raise ValueError (f"Unsupported layout: { self .layout } " )
@@ -980,38 +982,199 @@ def get_page_buffer_meta(self, indices):
980982 for layer_id in range (self .layer_num ):
981983 k_ptr = (
982984 kv_buffer_data_ptr
983- + indices [index ]
984- * (self .kv_lora_rank + self .qk_rope_head_dim )
985- * self .dtype .itemsize
986- + layer_id
987- * self .size
988- * (self .kv_lora_rank + self .qk_rope_head_dim )
989- * self .dtype .itemsize
985+ + indices [index ] * self .kv_cache_dim * self .dtype .itemsize
986+ + layer_id * self .size * self .kv_cache_dim * self .dtype .itemsize
990987 )
991988 ptr_list .append (k_ptr )
992- element_size = (
993- self .dtype .itemsize
994- * self .page_size
995- * (self .kv_lora_rank + self .qk_rope_head_dim )
996- )
989+ element_size = self .dtype .itemsize * self .page_size * self .kv_cache_dim
997990 element_size_list = [element_size ] * len (ptr_list )
998991 elif self .layout in ["page_first" , "page_first_direct" ]:
999992 for index in range (0 , len (indices ), self .page_size ):
1000993 k_ptr = (
1001994 kv_buffer_data_ptr
1002995 + indices [index ]
1003996 * self .layer_num
1004- * ( self .kv_lora_rank + self . qk_rope_head_dim )
997+ * self .kv_cache_dim
1005998 * self .dtype .itemsize
1006999 )
10071000 ptr_list .append (k_ptr )
10081001 element_size = (
10091002 self .layer_num
10101003 * self .dtype .itemsize
10111004 * self .page_size
1012- * ( self .kv_lora_rank + self . qk_rope_head_dim )
1005+ * self .kv_cache_dim
10131006 )
10141007 element_size_list = [element_size ] * len (ptr_list )
10151008 else :
10161009 raise ValueError (f"Unsupported layout: { self .layout } " )
10171010 return ptr_list , element_size_list
1011+
1012+
class NSATokenToKVPoolHost(MLATokenToKVPoolHost):
    """Host-side KV cache pool for NSA models.

    Extends the MLA host pool with per-layer "indexer" buffers (quantized
    index keys plus their fp32 scales, packed per page) that must be
    transferred alongside the regular KV data on every host<->device
    load/backup.
    """

    device_pool: NSATokenToKVPool

    def __init__(
        self,
        device_pool: NSATokenToKVPool,
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
        layout: str,
        pin_memory: bool = True,
        device: str = "cpu",
        allocator_type: str = "default",
    ):
        # Indexer metadata must be set before HostKVCache.__init__ runs,
        # because super().__init__ calls get_size_per_token(), which reads
        # these fields.
        self.index_head_dim = device_pool.index_head_dim
        self.indexer_quant_block_size = device_pool.quant_block_size
        self.indexer_dtype = NSATokenToKVPool.index_k_with_scale_buffer_dtype
        # Per token: index_head_dim quantized entries plus one 4-byte scale
        # per quantization block.
        self.indexer_size_per_token = (
            self.index_head_dim
            + self.index_head_dim // self.indexer_quant_block_size * 4
        )
        super().__init__(
            device_pool,
            host_to_device_ratio,
            host_size,
            page_size,
            layout,
            pin_memory,
            device,
            allocator_type,
            override_kv_cache_dim=device_pool.kv_cache_dim,
        )
        # Bytes occupied by one page of indexer data for a single layer.
        self.indexer_page_stride_size = (
            self.indexer_size_per_token * self.page_size * self.indexer_dtype.itemsize
        )
        # NOTE(review): standard ceil-division would be
        # (size + page_size - 1) // page_size; the "+ 1" here allocates one
        # extra page when size is a multiple of page_size. Harmless
        # over-allocation, but confirm it is intentional padding and not a
        # typo for "- 1".
        self.indexer_page_num = (self.size + self.page_size + 1) // self.page_size
        self._init_indexer_buffers()
        logger.info(
            f"NSATokenToKVPoolHost initialized with indexer page stride size: {self.indexer_page_stride_size}, page num: {self.indexer_page_num}"
        )

    def get_size_per_token(self):
        """Return host bytes needed per token: base MLA KV bytes plus the
        per-layer indexer bytes."""
        base = super().get_size_per_token()
        return (
            base
            + self.indexer_size_per_token * self.layer_num * self.indexer_dtype.itemsize
        )

    def _init_indexer_buffers(self):
        """Allocate one host indexer buffer per layer and cache the raw data
        pointers (host and device) used by the kernel transfer path."""
        alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
        self.index_k_with_scale_buffer = [
            alloc_func(
                (self.indexer_page_num, self.indexer_page_stride_size),
                dtype=self.indexer_dtype,
                device=self.device,
                pin_memory=self.pin_memory,
                allocator=self.allocator,
            )
            for _ in range(self.layer_num)
        ]
        # Keep direct references so the tensors backing the raw pointers
        # below cannot be garbage-collected.
        self.index_k_data_refs = list(self.index_k_with_scale_buffer)
        self.index_k_data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.index_k_data_refs],
            dtype=torch.uint64,
            device=self.device_pool.device,
        )
        self.index_k_device_ptrs = torch.tensor(
            [x.data_ptr() for x in self.device_pool.index_k_with_scale_buffer],
            dtype=torch.uint64,
            device=self.device_pool.device,
        )

    def _get_indexer_page_indices(self, host_indices, device_indices):
        """Convert page-aligned token indices to page indices.

        Both index tensors must contain whole pages (numel divisible by
        page_size); the first token of each page determines its page index.
        Raises ValueError on misaligned input.
        """
        if host_indices.numel() == 0:
            return host_indices, device_indices
        if (
            host_indices.numel() % self.page_size != 0
            or device_indices.numel() % self.page_size != 0
        ):
            raise ValueError(
                "Index buffer transfer expects page-aligned indices for NSA."
            )
        host_page_indices = (
            host_indices.reshape(-1, self.page_size)[:, 0] // self.page_size
        )
        device_page_indices = (
            device_indices.reshape(-1, self.page_size)[:, 0] // self.page_size
        )
        return host_page_indices, device_page_indices

    def _indexer_use_kernel(self, io_backend):
        # Kernel transfer path requires the per-page stride to be 8-byte
        # aligned; otherwise fall back to direct tensor copies.
        return io_backend == "kernel" and self.indexer_page_stride_size % 8 == 0

    def _load_indexer_to_device_per_layer(
        self, device_pool, host_indices, device_indices, layer_id, io_backend
    ):
        """Copy one layer's indexer pages host -> device."""
        host_page_indices, device_page_indices = self._get_indexer_page_indices(
            host_indices, device_indices
        )
        use_kernel = self._indexer_use_kernel(io_backend)
        if use_kernel:
            transfer_kv_per_layer_mla(
                src=self.index_k_with_scale_buffer[layer_id],
                dst=device_pool.index_k_with_scale_buffer[layer_id],
                src_indices=host_page_indices,
                dst_indices=device_page_indices,
                item_size=self.indexer_page_stride_size,
            )
        else:
            transfer_kv_direct(
                src_layers=[self.index_k_with_scale_buffer[layer_id]],
                dst_layers=[device_pool.index_k_with_scale_buffer[layer_id]],
                src_indices=host_page_indices,
                dst_indices=device_page_indices,
                page_size=1,
            )

        if layer_id == 0:
            # Debug-level and lazy %-formatting: this runs on every transfer,
            # and formatting index tensors eagerly (f-string at INFO) is
            # expensive hot-path log spam.
            logger.debug(
                "NSATokenToKVPoolHost loaded indexer to device for layer %s, "
                "use_kernel: %s, host_page_indices: %s, device_page_indices: %s",
                layer_id,
                use_kernel,
                host_page_indices,
                device_page_indices,
            )

    def _backup_indexer_from_device_all_layer(
        self, device_pool, host_indices, device_indices, io_backend
    ):
        """Copy every layer's indexer pages device -> host."""
        host_page_indices, device_page_indices = self._get_indexer_page_indices(
            host_indices, device_indices
        )
        if self._indexer_use_kernel(io_backend):
            transfer_kv_all_layer_mla(
                src_layers=self.index_k_device_ptrs,
                dst_layers=self.index_k_data_ptrs,
                src_indices=device_page_indices,
                dst_indices=host_page_indices,
                item_size=self.indexer_page_stride_size,
                num_layers=self.layer_num,
            )
        else:
            transfer_kv_direct(
                src_layers=device_pool.index_k_with_scale_buffer,
                dst_layers=self.index_k_with_scale_buffer,
                src_indices=device_page_indices,
                dst_indices=host_page_indices,
                page_size=1,
            )

    def load_to_device_per_layer(
        self,
        device_pool,
        host_indices,
        device_indices,
        layer_id,
        io_backend,
    ):
        """Load one layer's KV data (via the MLA base class) and its indexer
        pages to the device."""
        super().load_to_device_per_layer(
            device_pool, host_indices, device_indices, layer_id, io_backend
        )
        self._load_indexer_to_device_per_layer(
            device_pool, host_indices, device_indices, layer_id, io_backend
        )

    def backup_from_device_all_layer(
        self, device_pool, host_indices, device_indices, io_backend
    ):
        """Back up all layers' KV data (via the MLA base class) and their
        indexer pages from the device to the host."""
        super().backup_from_device_all_layer(
            device_pool, host_indices, device_indices, io_backend
        )
        self._backup_indexer_from_device_all_layer(
            device_pool, host_indices, device_indices, io_backend
        )
0 commit comments