Skip to content

Commit b6688cd

Browse files
hzh0425, huangtingwei9988, xiezhq-hermann
committed
Support v32 cpu offloading
Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
1 parent 8fb4552 commit b6688cd

File tree

3 files changed

+339
-32
lines changed

3 files changed

+339
-32
lines changed

python/sglang/srt/mem_cache/hiradix_cache.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,15 @@
1111

1212
from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
1313
from sglang.srt.mem_cache.base_prefix_cache import MatchPrefixParams, MatchResult
14-
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
14+
from sglang.srt.mem_cache.memory_pool import (
15+
MHATokenToKVPool,
16+
MLATokenToKVPool,
17+
NSATokenToKVPool,
18+
)
1519
from sglang.srt.mem_cache.memory_pool_host import (
1620
MHATokenToKVPoolHost,
1721
MLATokenToKVPoolHost,
22+
NSATokenToKVPoolHost,
1823
)
1924
from sglang.srt.mem_cache.radix_cache import (
2025
RadixCache,
@@ -59,6 +64,15 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs):
5964
server_args.hicache_mem_layout,
6065
allocator_type=server_args.hicache_storage_backend,
6166
)
67+
elif isinstance(self.kv_cache, NSATokenToKVPool):
68+
self.token_to_kv_pool_host = NSATokenToKVPoolHost(
69+
self.kv_cache,
70+
server_args.hicache_ratio,
71+
server_args.hicache_size,
72+
self.page_size,
73+
server_args.hicache_mem_layout,
74+
allocator_type=server_args.hicache_storage_backend,
75+
)
6276
elif isinstance(self.kv_cache, MLATokenToKVPool):
6377
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
6478
self.kv_cache,

python/sglang/srt/mem_cache/memory_pool_host.py

Lines changed: 194 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
from sglang.jit_kernel.hicache import (
1616
transfer_hicache_one_layer as jit_transfer_hicache_one_layer,
1717
)
18-
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
18+
from sglang.srt.mem_cache.memory_pool import (
19+
KVCache,
20+
MHATokenToKVPool,
21+
MLATokenToKVPool,
22+
NSATokenToKVPool,
23+
)
1924
from sglang.srt.utils import is_cuda, is_npu, is_xpu
2025

2126
_is_cuda = is_cuda()
@@ -689,7 +694,9 @@ def __init__(
689694
pin_memory: bool = True,
690695
device: str = "cpu",
691696
allocator_type: str = "default",
697+
override_kv_cache_dim: Optional[int] = None,
692698
):
699+
self.override_kv_cache_dim = override_kv_cache_dim
693700
super().__init__(
694701
device_pool,
695702
host_to_device_ratio,
@@ -711,13 +718,10 @@ def get_size_per_token(self):
711718
self.kv_lora_rank = self.device_pool.kv_lora_rank
712719
self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
713720
self.layer_num = self.device_pool.layer_num
714-
715-
return (
716-
(self.kv_lora_rank + self.qk_rope_head_dim)
717-
* 1
718-
* self.dtype.itemsize
719-
* self.layer_num
721+
self.kv_cache_dim = self.override_kv_cache_dim or (
722+
self.kv_lora_rank + self.qk_rope_head_dim
720723
)
724+
return self.kv_cache_dim * self.dtype.itemsize * self.layer_num
721725

722726
def get_ksize_per_token(self):
723727
return self.get_size_per_token()
@@ -728,22 +732,22 @@ def init_kv_buffer(self):
728732
self.layer_num,
729733
self.size,
730734
1,
731-
self.kv_lora_rank + self.qk_rope_head_dim,
735+
self.kv_cache_dim,
732736
)
733737
elif self.layout == "page_first":
734738
dims = (
735739
self.size,
736740
self.layer_num,
737741
1,
738-
self.kv_lora_rank + self.qk_rope_head_dim,
742+
self.kv_cache_dim,
739743
)
740744
elif self.layout == "page_first_direct":
741745
dims = (
742746
self.page_num,
743747
self.layer_num,
744748
self.page_size,
745749
1,
746-
self.kv_lora_rank + self.qk_rope_head_dim,
750+
self.kv_cache_dim,
747751
)
748752
# Ascend-specific: Aligns with NPUMLATokenToKVPool layout
749753
# Separately allocate k_buffer and v_buffer for easier data transfer.
@@ -774,9 +778,7 @@ def init_kv_buffer(self):
774778
return self.k_buffer
775779
else:
776780
raise ValueError(f"Unsupported layout: {self.layout}")
777-
self.token_stride_size = (
778-
self.kv_lora_rank + self.qk_rope_head_dim
779-
) * self.dtype.itemsize
781+
self.token_stride_size = self.kv_cache_dim * self.dtype.itemsize
780782
self.layout_dim = self.token_stride_size * self.layer_num
781783

782784
alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
@@ -933,7 +935,7 @@ def get_dummy_flat_data_page(self) -> torch.Tensor:
933935
self.layer_num,
934936
self.page_size,
935937
1,
936-
self.kv_lora_rank + self.qk_rope_head_dim,
938+
self.kv_cache_dim,
937939
),
938940
dtype=self.dtype,
939941
device=self.device,
@@ -946,14 +948,14 @@ def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
946948
self.layer_num,
947949
self.page_size,
948950
1,
949-
self.kv_lora_rank + self.qk_rope_head_dim,
951+
self.kv_cache_dim,
950952
)
951953
elif self.layout == "page_first":
952954
self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
953955
self.page_size,
954956
self.layer_num,
955957
1,
956-
self.kv_lora_rank + self.qk_rope_head_dim,
958+
self.kv_cache_dim,
957959
)
958960
elif self.layout == "page_first_direct":
959961
real_index = index // self.page_size
@@ -962,7 +964,7 @@ def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
962964
self.layer_num,
963965
self.page_size,
964966
1,
965-
self.kv_lora_rank + self.qk_rope_head_dim,
967+
self.kv_cache_dim,
966968
)
967969
else:
968970
raise ValueError(f"Unsupported layout: {self.layout}")
@@ -980,38 +982,199 @@ def get_page_buffer_meta(self, indices):
980982
for layer_id in range(self.layer_num):
981983
k_ptr = (
982984
kv_buffer_data_ptr
983-
+ indices[index]
984-
* (self.kv_lora_rank + self.qk_rope_head_dim)
985-
* self.dtype.itemsize
986-
+ layer_id
987-
* self.size
988-
* (self.kv_lora_rank + self.qk_rope_head_dim)
989-
* self.dtype.itemsize
985+
+ indices[index] * self.kv_cache_dim * self.dtype.itemsize
986+
+ layer_id * self.size * self.kv_cache_dim * self.dtype.itemsize
990987
)
991988
ptr_list.append(k_ptr)
992-
element_size = (
993-
self.dtype.itemsize
994-
* self.page_size
995-
* (self.kv_lora_rank + self.qk_rope_head_dim)
996-
)
989+
element_size = self.dtype.itemsize * self.page_size * self.kv_cache_dim
997990
element_size_list = [element_size] * len(ptr_list)
998991
elif self.layout in ["page_first", "page_first_direct"]:
999992
for index in range(0, len(indices), self.page_size):
1000993
k_ptr = (
1001994
kv_buffer_data_ptr
1002995
+ indices[index]
1003996
* self.layer_num
1004-
* (self.kv_lora_rank + self.qk_rope_head_dim)
997+
* self.kv_cache_dim
1005998
* self.dtype.itemsize
1006999
)
10071000
ptr_list.append(k_ptr)
10081001
element_size = (
10091002
self.layer_num
10101003
* self.dtype.itemsize
10111004
* self.page_size
1012-
* (self.kv_lora_rank + self.qk_rope_head_dim)
1005+
* self.kv_cache_dim
10131006
)
10141007
element_size_list = [element_size] * len(ptr_list)
10151008
else:
10161009
raise ValueError(f"Unsupported layout: {self.layout}")
10171010
return ptr_list, element_size_list
1011+
1012+
1013+
class NSATokenToKVPoolHost(MLATokenToKVPoolHost):
    """Host (CPU) KV cache pool for NSA models.

    Extends :class:`MLATokenToKVPoolHost` with per-layer "indexer" buffers that
    mirror the device pool's ``index_k_with_scale_buffer``. The indexer data is
    transferred alongside the regular KV data on every per-layer load and
    all-layer backup, at page granularity.
    """

    device_pool: NSATokenToKVPool

    def __init__(
        self,
        device_pool: NSATokenToKVPool,
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
        layout: str,
        pin_memory: bool = True,
        device: str = "cpu",
        allocator_type: str = "default",
    ):
        # Indexer metadata must be set before super().__init__, because
        # HostKVCache.__init__ calls get_size_per_token(), which reads it.
        self.index_head_dim = device_pool.index_head_dim
        self.indexer_quant_block_size = device_pool.quant_block_size
        self.indexer_dtype = NSATokenToKVPool.index_k_with_scale_buffer_dtype
        # Per-token indexer payload: the quantized index keys plus 4 bytes of
        # scale for each quantization block.
        self.indexer_size_per_token = (
            self.index_head_dim
            + self.index_head_dim // self.indexer_quant_block_size * 4
        )
        super().__init__(
            device_pool,
            host_to_device_ratio,
            host_size,
            page_size,
            layout,
            pin_memory,
            device,
            allocator_type,
            # The NSA device pool packs KV rows kv_cache_dim wide, which may
            # differ from kv_lora_rank + qk_rope_head_dim used by plain MLA.
            override_kv_cache_dim=device_pool.kv_cache_dim,
        )
        # Bytes occupied by one page of indexer data for a single layer.
        self.indexer_page_stride_size = (
            self.indexer_size_per_token * self.page_size * self.indexer_dtype.itemsize
        )
        # NOTE(review): standard ceiling division is
        # (size + page_size - 1) // page_size; the "+ 1" here can allocate one
        # extra page. Kept as-is in case the slack page is intentional — confirm.
        self.indexer_page_num = (self.size + self.page_size + 1) // self.page_size
        self._init_indexer_buffers()
        logger.info(
            "NSATokenToKVPoolHost initialized with indexer page stride size: %s, page num: %s",
            self.indexer_page_stride_size,
            self.indexer_page_num,
        )

    def get_size_per_token(self):
        """Return host bytes required per token: base MLA KV plus indexer data."""
        base = super().get_size_per_token()
        return (
            base
            + self.indexer_size_per_token * self.layer_num * self.indexer_dtype.itemsize
        )

    def _init_indexer_buffers(self):
        """Allocate one host indexer buffer per layer and cache raw pointers.

        The pointer tensors (uint64, resident on the device) are what the
        transfer kernels consume for the all-layer backup path.
        """
        alloc_func = ALLOC_MEMORY_FUNCS[self.device_pool.device]
        self.index_k_with_scale_buffer = [
            alloc_func(
                (self.indexer_page_num, self.indexer_page_stride_size),
                dtype=self.indexer_dtype,
                device=self.device,
                pin_memory=self.pin_memory,
                allocator=self.allocator,
            )
            for _ in range(self.layer_num)
        ]
        # Keep strong references so the buffers backing the raw pointers below
        # cannot be garbage-collected.
        self.index_k_data_refs = list(self.index_k_with_scale_buffer)
        self.index_k_data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.index_k_data_refs],
            dtype=torch.uint64,
            device=self.device_pool.device,
        )
        self.index_k_device_ptrs = torch.tensor(
            [x.data_ptr() for x in self.device_pool.index_k_with_scale_buffer],
            dtype=torch.uint64,
            device=self.device_pool.device,
        )

    def _get_indexer_page_indices(self, host_indices, device_indices):
        """Convert page-aligned token indices into page indices.

        Raises:
            ValueError: if ``host_indices`` is not a whole number of pages.
        """
        if host_indices.numel() == 0:
            return host_indices, device_indices
        if host_indices.numel() % self.page_size != 0:
            raise ValueError(
                "Index buffer transfer expects page-aligned indices for NSA."
            )
        # First token index of each page, scaled down to a page index.
        host_page_indices = (
            host_indices.reshape(-1, self.page_size)[:, 0] // self.page_size
        )
        device_page_indices = (
            device_indices.reshape(-1, self.page_size)[:, 0] // self.page_size
        )
        return host_page_indices, device_page_indices

    def _load_indexer_to_device_per_layer(
        self, device_pool, host_indices, device_indices, layer_id, io_backend
    ):
        """Copy one layer's indexer pages host -> device."""
        host_page_indices, device_page_indices = self._get_indexer_page_indices(
            host_indices, device_indices
        )
        # The kernel path requires 8-byte-aligned page strides.
        use_kernel = io_backend == "kernel" and self.indexer_page_stride_size % 8 == 0
        if use_kernel:
            transfer_kv_per_layer_mla(
                src=self.index_k_with_scale_buffer[layer_id],
                dst=device_pool.index_k_with_scale_buffer[layer_id],
                src_indices=host_page_indices,
                dst_indices=device_page_indices,
                item_size=self.indexer_page_stride_size,
            )
        else:
            transfer_kv_direct(
                src_layers=[self.index_k_with_scale_buffer[layer_id]],
                dst_layers=[device_pool.index_k_with_scale_buffer[layer_id]],
                src_indices=host_page_indices,
                dst_indices=device_page_indices,
                # Indices are already page indices, so transfer with page_size=1.
                page_size=1,
            )

        if layer_id == 0:
            # Demoted from info to debug: this fires on every load and
            # formatting the index tensors is expensive and floods the log.
            logger.debug(
                "NSATokenToKVPoolHost loaded indexer to device for layer %s, use_kernel: %s, host_page_indices: %s, device_page_indices: %s",
                layer_id,
                use_kernel,
                host_page_indices,
                device_page_indices,
            )

    def _backup_indexer_from_device_all_layer(
        self, device_pool, host_indices, device_indices, io_backend
    ):
        """Copy all layers' indexer pages device -> host."""
        host_page_indices, device_page_indices = self._get_indexer_page_indices(
            host_indices, device_indices
        )
        use_kernel = io_backend == "kernel" and self.indexer_page_stride_size % 8 == 0
        if use_kernel:
            transfer_kv_all_layer_mla(
                src_layers=self.index_k_device_ptrs,
                dst_layers=self.index_k_data_ptrs,
                src_indices=device_page_indices,
                dst_indices=host_page_indices,
                item_size=self.indexer_page_stride_size,
            num_layers=self.layer_num,
            )
        else:
            transfer_kv_direct(
                src_layers=device_pool.index_k_with_scale_buffer,
                dst_layers=self.index_k_with_scale_buffer,
                src_indices=device_page_indices,
                dst_indices=host_page_indices,
                page_size=1,
            )

    def load_to_device_per_layer(
        self,
        device_pool,
        host_indices,
        device_indices,
        layer_id,
        io_backend,
    ):
        """Load one layer's KV data, then its indexer data, host -> device."""
        super().load_to_device_per_layer(
            device_pool, host_indices, device_indices, layer_id, io_backend
        )
        self._load_indexer_to_device_per_layer(
            device_pool, host_indices, device_indices, layer_id, io_backend
        )

    def backup_from_device_all_layer(
        self, device_pool, host_indices, device_indices, io_backend
    ):
        """Back up all layers' KV data, then all indexer data, device -> host."""
        super().backup_from_device_all_layer(
            device_pool, host_indices, device_indices, io_backend
        )
        self._backup_indexer_from_device_all_layer(
            device_pool, host_indices, device_indices, io_backend
        )

0 commit comments

Comments (0)