Skip to content

Commit c4fc22c

Browse files
cctryJohn Doe
authored andcommitted
[Fix] data race in req_to_token pool (sgl-project#17850)
1 parent 44351ca commit c4fc22c

File tree

13 files changed

+109
-115
lines changed

13 files changed

+109
-115
lines changed

python/sglang/srt/disaggregation/decode.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from collections import deque
2626
from dataclasses import dataclass
2727
from http import HTTPStatus
28-
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
28+
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
2929

3030
import torch
3131
from torch.distributed import ProcessGroup
@@ -116,19 +116,31 @@ def write(self, indices, values):
116116
def available_size(self):
117117
return len(self.free_slots)
118118

119-
def alloc(self, need_size: int) -> List[int]:
119+
def alloc(self, reqs: List["Req"]) -> Optional[List[int]]:
120+
chunked = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None]
121+
assert (
122+
len(chunked) <= 1
123+
), "only one chunked request may reuse req_pool_idx in a batch"
124+
assert all(
125+
reqs[i].is_chunked > 0 or reqs[i].kv_committed_len > 0 for i in chunked
126+
), "request has req_pool_idx but is not chunked"
127+
128+
need_size = len(reqs) - len(chunked)
120129
if need_size > len(self.free_slots):
121130
return None
122-
123131
select_index = self.free_slots[:need_size]
124132
self.free_slots = self.free_slots[need_size:]
125-
return select_index
126-
127-
def free(self, free_index: Union[int, List[int]]):
128-
if isinstance(free_index, (int,)):
129-
self.free_slots.append(free_index)
130-
else:
131-
self.free_slots.extend(free_index)
133+
offset = 0
134+
for r in reqs:
135+
if r.req_pool_idx is None:
136+
r.req_pool_idx = select_index[offset]
137+
offset += 1
138+
return [r.req_pool_idx for r in reqs]
139+
140+
def free(self, req: "Req"):
141+
assert req.req_pool_idx is not None, "request must have req_pool_idx"
142+
self.free_slots.append(req.req_pool_idx)
143+
req.req_pool_idx = None
132144

133145
def clear(self):
134146
self.free_slots = list(range(self.size + self.pre_alloc_size))
@@ -652,17 +664,12 @@ def _allocatable_tokens(
652664

653665
def _pre_alloc(self, req: Req) -> torch.Tensor:
654666
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
655-
if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool):
656-
req_pool_indices = self.req_to_token_pool.alloc(1, [req])
657-
else:
658-
req_pool_indices = self.req_to_token_pool.alloc(1)
667+
req_pool_indices = self.req_to_token_pool.alloc([req])
659668

660669
assert (
661670
req_pool_indices is not None
662671
), "req_pool_indices is full! There is a bug in memory estimation."
663672

664-
req.req_pool_idx = req_pool_indices[0]
665-
666673
# Alloc all tokens for the prebuilt req (except for the reserved input token for decoding)
667674
fill_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
668675
req.kv_allocated_len = fill_len

python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def _release_finished_req(self, req: Req, prefill_offloaded_len: int):
191191

192192
# Free the incremental part of the request
193193
self.token_to_kv_pool_allocator.free(kv_indices)
194-
self.req_to_token_pool.free(req.req_pool_idx)
194+
self.req_to_token_pool.free(req)
195195
self.tree_cache.protected_size_ -= len(req.prefix_indices)
196196

197197
def _check_backup_progress(self, finish_count):

python/sglang/srt/disaggregation/prefill.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -632,13 +632,6 @@ def process_prefill_chunk(self: Scheduler) -> None:
632632
)
633633
else:
634634
self.send_kv_chunk(self.chunked_req)
635-
# chunked request keeps its rid but will get a new req_pool_idx
636-
if self.tp_worker.model_runner.mambaish_config is not None:
637-
self.req_to_token_pool.free(
638-
self.chunked_req.req_pool_idx, free_mamba_cache=False
639-
)
640-
else:
641-
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
642635
self.running_batch.batch_is_full = False
643636

644637
if self.last_batch and self.last_batch.forward_mode.is_extend():

python/sglang/srt/managers/scheduler.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1792,11 +1792,6 @@ def handle_batch_embedding_request(
17921792

17931793
def stash_chunked_request(self, req: Req):
17941794
self.tree_cache.cache_unfinished_req(req, chunked=True)
1795-
# Chunked request keeps its rid but will get a new req_pool_idx
1796-
if self.tp_worker.model_runner.mambaish_config is not None:
1797-
self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=False)
1798-
else:
1799-
self.req_to_token_pool.free(req.req_pool_idx)
18001795

18011796
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
18021797
self._abort_on_queued_timeout()

python/sglang/srt/managers/scheduler_pp_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ def profile_and_init_predictor(self: Scheduler):
644644
req.req_pool_idx, : len(req.fill_ids)
645645
]
646646
self.token_to_kv_pool_allocator.free(kv_indices)
647-
self.req_to_token_pool.free(req.req_pool_idx)
647+
self.req_to_token_pool.free(req)
648648

649649
logger.info(
650650
f"[PP Dynamic Chunk] [PP0] Profiled {len(seq_lens)} samples: "

python/sglang/srt/mem_cache/chunk_cache.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def cache_finished_req(self, req: Req, is_insert: bool = True):
6868
kv_indices = self.req_to_token_pool.req_to_token[
6969
req.req_pool_idx, :kv_committed_len
7070
]
71-
self.req_to_token_pool.free(req.req_pool_idx)
7271
self.token_to_kv_pool_allocator.free(kv_indices)
7372

7473
def cache_unfinished_req(self, req: Req, chunked=False):

python/sglang/srt/mem_cache/common.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,11 @@ def alloc_paged_token_slots_extend(
296296

297297
def alloc_req_slots(
298298
req_to_token_pool: ReqToTokenPool,
299-
num_reqs: int,
300-
reqs: list[Req] | None,
299+
reqs: list[Req],
301300
tree_cache: BasePrefixCache | None,
302301
) -> list[int]:
303302
"""Allocate request slots from the pool."""
303+
num_reqs = len(reqs)
304304
if isinstance(req_to_token_pool, HybridReqToTokenPool):
305305
mamba_available_size = req_to_token_pool.mamba_pool.available_size()
306306
factor = (
@@ -313,9 +313,7 @@ def alloc_req_slots(
313313
if tree_cache is not None and tree_cache.supports_mamba():
314314
mamba_num = max(0, mamba_state_needed - mamba_available_size)
315315
tree_cache.evict(EvictParams(num_tokens=0, mamba_num=mamba_num))
316-
req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs)
317-
else:
318-
req_pool_indices = req_to_token_pool.alloc(num_reqs)
316+
req_pool_indices = req_to_token_pool.alloc(reqs)
319317

320318
if req_pool_indices is None:
321319
raise RuntimeError(
@@ -341,7 +339,6 @@ def alloc_for_extend(
341339
# free out-of-window swa tokens
342340
batch.maybe_evict_swa()
343341

344-
bs = len(batch.reqs)
345342
prefix_tensors = [r.prefix_indices for r in batch.reqs]
346343

347344
# Create tensors for allocation
@@ -352,7 +349,7 @@ def alloc_for_extend(
352349

353350
# Allocate req slots
354351
req_pool_indices = alloc_req_slots(
355-
batch.req_to_token_pool, bs, batch.reqs, batch.tree_cache
352+
batch.req_to_token_pool, batch.reqs, batch.tree_cache
356353
)
357354
req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64)
358355
req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True)
@@ -466,15 +463,21 @@ def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor:
466463

467464

468465
def release_kv_cache(req: Req, tree_cache: BasePrefixCache, is_insert: bool = True):
469-
tree_cache.cache_finished_req(req, is_insert=is_insert)
470-
471466
# MambaRadixCache may alloc mamba state before alloc KV cache
472467
if req.req_pool_idx is None:
473468
assert (
474469
tree_cache.supports_mamba()
475-
), "Only MambaRadixCache can handle abort with prefix cache hit before alloc"
470+
), "Only MambaRadixCache allow freeing before alloc"
471+
# TODO (csy, hanming): clean up this early allocation logic
472+
if req.mamba_pool_idx is not None:
473+
tree_cache.req_to_token_pool.mamba_pool.free(
474+
req.mamba_pool_idx.unsqueeze(-1)
475+
)
476+
req.mamba_pool_idx = None
476477
return
477478

479+
tree_cache.cache_finished_req(req, is_insert=is_insert)
480+
478481
start_p, end_p = req.pop_overallocated_kv_cache()
479482

480483
global_server_args = get_global_server_args()
@@ -489,13 +492,20 @@ def release_kv_cache(req: Req, tree_cache: BasePrefixCache, is_insert: bool = Tr
489492
if page_size > 1:
490493
start_p = ceil_align(start_p, page_size)
491494

492-
if start_p >= end_p:
493-
return
494-
495-
indices_to_free = tree_cache.req_to_token_pool.req_to_token[req.req_pool_idx][
496-
start_p:end_p
497-
]
498-
tree_cache.token_to_kv_pool_allocator.free(indices_to_free)
495+
if start_p < end_p:
496+
indices_to_free = tree_cache.req_to_token_pool.req_to_token[req.req_pool_idx][
497+
start_p:end_p
498+
]
499+
tree_cache.token_to_kv_pool_allocator.free(indices_to_free)
500+
# If the prefix cache doesn't manage mamba states, we must free them here.
501+
if isinstance(tree_cache.req_to_token_pool, HybridReqToTokenPool) and (
502+
not tree_cache.supports_mamba()
503+
):
504+
assert (
505+
req.mamba_pool_idx is not None
506+
), "mamba state is freed while the tree cache does not manage mamba states"
507+
tree_cache.req_to_token_pool.free_mamba_cache(req)
508+
tree_cache.req_to_token_pool.free(req)
499509

500510

501511
def available_and_evictable_str(tree_cache) -> str:

python/sglang/srt/mem_cache/mamba_radix_cache.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -499,21 +499,14 @@ def insert(self, params: InsertParams) -> InsertResult:
499499

500500
def cache_finished_req(self, req: Req, is_insert: bool = True) -> None:
501501
"""Cache request when it finishes."""
502-
# for abort with prefix cache hit and before alloc is called
503-
if req.req_pool_idx is None:
504-
if req.mamba_pool_idx is not None:
505-
self.req_to_token_pool.mamba_pool.free(req.mamba_pool_idx.unsqueeze(-1))
506-
req.mamba_pool_idx = None
507-
return
508-
509502
kv_committed_len = req.pop_committed_kv_cache()
510503

511504
if self.disable:
512505
kv_indices = self.req_to_token_pool.req_to_token[
513506
req.req_pool_idx, :kv_committed_len
514507
]
515508
self.token_to_kv_pool_allocator.free(kv_indices)
516-
self.req_to_token_pool.free(req.req_pool_idx)
509+
self.req_to_token_pool.free_mamba_cache(req)
517510
return
518511

519512
token_ids = (req.origin_input_ids + req.output_ids)[:kv_committed_len]
@@ -588,11 +581,11 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None:
588581

589582
free_mamba_cache = True if self.enable_mamba_extra_buffer else mamba_exist
590583

591-
self.req_to_token_pool.free(
592-
req.req_pool_idx,
593-
free_mamba_cache=free_mamba_cache,
594-
mamba_ping_pong_track_buffer_to_keep=mamba_ping_pong_track_buffer_to_keep,
595-
)
584+
if free_mamba_cache:
585+
self.req_to_token_pool.free_mamba_cache(
586+
req,
587+
mamba_ping_pong_track_buffer_to_keep=mamba_ping_pong_track_buffer_to_keep,
588+
)
596589

597590
self.dec_lock_ref(req.last_node)
598591

python/sglang/srt/mem_cache/memory_pool.py

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ def __init__(
133133
device: str,
134134
enable_memory_saver: bool,
135135
):
136-
137136
memory_saver_adapter = TorchMemorySaverAdapter.create(
138137
enable=enable_memory_saver
139138
)
@@ -145,7 +144,6 @@ def __init__(
145144
self.req_to_token = torch.zeros(
146145
(size, max_context_len), dtype=torch.int32, device=device
147146
)
148-
149147
self.free_slots = list(range(size))
150148

151149
def write(self, indices, values):
@@ -154,20 +152,32 @@ def write(self, indices, values):
154152
def available_size(self):
155153
return len(self.free_slots)
156154

157-
def alloc(self, need_size: int) -> List[int]:
155+
def alloc(self, reqs: list[Req]) -> Optional[List[int]]:
156+
chunked = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None]
157+
if not any(r.is_dllm() for r in reqs):
158+
assert (
159+
len(chunked) <= 1
160+
), "only one chunked request may reuse req_pool_idx in a batch"
161+
assert all(
162+
reqs[i].is_chunked > 0 or reqs[i].kv_committed_len > 0 for i in chunked
163+
), "request has req_pool_idx but is not chunked"
164+
165+
need_size = len(reqs) - len(chunked)
158166
if need_size > len(self.free_slots):
159167
return None
160-
161168
select_index = self.free_slots[:need_size]
162169
self.free_slots = self.free_slots[need_size:]
163-
164-
return select_index
165-
166-
def free(self, free_index: Union[int, List[int]]):
167-
if isinstance(free_index, (int,)):
168-
self.free_slots.append(free_index)
169-
else:
170-
self.free_slots.extend(free_index)
170+
offset = 0
171+
for r in reqs:
172+
if r.req_pool_idx is None:
173+
r.req_pool_idx = select_index[offset]
174+
offset += 1
175+
return [r.req_pool_idx for r in reqs]
176+
177+
def free(self, req: Req):
178+
assert req.req_pool_idx is not None, "request must have req_pool_idx"
179+
self.free_slots.append(req.req_pool_idx)
180+
req.req_pool_idx = None
171181

172182
def clear(self):
173183
self.free_slots = list(range(self.size))
@@ -488,10 +498,9 @@ def _init_mamba_pool(
488498

489499
# For chunk prefill req, we do not need to allocate mamba cache,
490500
# We could use allocated mamba cache instead.
491-
def alloc(self, need_size: int, reqs: Optional[List["Req"]]) -> Optional[List[int]]:
492-
assert reqs is not None
493-
select_index = super().alloc(need_size)
494-
if select_index == None:
501+
def alloc(self, reqs: List["Req"]) -> Optional[List[int]]:
502+
select_index = super().alloc(reqs)
503+
if select_index is None:
495504
return None
496505

497506
mamba_index = []
@@ -556,37 +565,29 @@ def get_mamba_ping_pong_other_idx(self, mamba_next_track_idx: int) -> int:
556565
else:
557566
return mamba_next_track_idx
558567

559-
# For chunk prefill, we can not free mamba cache, we need use it in the future
560-
def free(
561-
self,
562-
free_index: Union[int, List[int]],
563-
free_mamba_cache: bool = True,
564-
mamba_ping_pong_track_buffer_to_keep: Optional[int] = None,
568+
def free_mamba_cache(
569+
self, req: "Req", mamba_ping_pong_track_buffer_to_keep: Optional[int] = None
565570
):
566-
if isinstance(free_index, (int,)):
567-
free_index = [free_index]
568-
super().free(free_index)
569-
if free_mamba_cache:
570-
mamba_index = self.req_index_to_mamba_index_mapping[free_index]
571-
self.mamba_pool.free(mamba_index)
571+
mamba_index = req.mamba_pool_idx
572+
assert mamba_index is not None, "double free? mamba_index is None"
573+
self.mamba_pool.free(mamba_index.unsqueeze(0))
574+
req.mamba_pool_idx = None
572575

573-
if self.enable_mamba_extra_buffer:
576+
if self.enable_mamba_extra_buffer:
577+
mamba_ping_pong_track_buffer_to_free = (
578+
self.req_index_to_mamba_ping_pong_track_buffer_mapping[req.req_pool_idx]
579+
)
580+
if mamba_ping_pong_track_buffer_to_keep is not None:
581+
assert mamba_ping_pong_track_buffer_to_keep in [
582+
0,
583+
1,
584+
], f"mamba_ping_pong_track_buffer_to_keep must be 0 or 1, {mamba_ping_pong_track_buffer_to_keep=}"
585+
idx_to_free = list(range(self.mamba_ping_pong_track_buffer_size))
586+
idx_to_free.remove(mamba_ping_pong_track_buffer_to_keep)
574587
mamba_ping_pong_track_buffer_to_free = (
575-
self.req_index_to_mamba_ping_pong_track_buffer_mapping[
576-
free_index
577-
].squeeze(0)
588+
mamba_ping_pong_track_buffer_to_free[idx_to_free]
578589
)
579-
if mamba_ping_pong_track_buffer_to_keep is not None:
580-
assert mamba_ping_pong_track_buffer_to_keep in [
581-
0,
582-
1,
583-
], f"mamba_ping_pong_track_buffer_to_keep must be 0 or 1, {mamba_ping_pong_track_buffer_to_keep=}"
584-
idx_to_free = list(range(self.mamba_ping_pong_track_buffer_size))
585-
idx_to_free.remove(mamba_ping_pong_track_buffer_to_keep)
586-
mamba_ping_pong_track_buffer_to_free = (
587-
mamba_ping_pong_track_buffer_to_free[idx_to_free]
588-
)
589-
self.mamba_pool.free(mamba_ping_pong_track_buffer_to_free)
590+
self.mamba_pool.free(mamba_ping_pong_track_buffer_to_free)
590591

591592
def clear(self):
592593
logger.info("Reset HybridReqToTokenPool")

0 commit comments

Comments
 (0)