Skip to content

Commit 91e2022

Browse files
committed
Fix mamba slot leak when scheduling fails with radix cache (#15840)
When add_one_req fails after init_next_round_input has allocated a mamba slot via COW (copy-on-write) during match_prefix, the slot was not released, causing a memory leak. This fix releases the mamba slot when scheduling fails.
1 parent f3d73b0 commit 91e2022

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

python/sglang/srt/managers/scheduler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1982,6 +1982,12 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
19821982
)
19831983

19841984
if res != AddReqResult.CONTINUE:
1985+
# Release mamba slot allocated via COW if scheduling fails
1986+
if self.is_hybrid_ssm and req.mamba_pool_idx is not None:
1987+
self.req_to_token_pool.mamba_pool.free(
1988+
req.mamba_pool_idx.unsqueeze(-1)
1989+
)
1990+
req.mamba_pool_idx = None
19851991
if res == AddReqResult.NO_TOKEN:
19861992
if self.enable_hierarchical_cache:
19871993
# Set batch_is_full after making sure there are requests that can be served

test/srt/test_mamba_unittest.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,110 @@ def make_dummy_req():
336336
== mamba_pool.mamba_cache.temporal[:, last_node.mamba_value]
337337
)
338338

339+
def test_mamba_slot_release_after_match_prefix_cow(self):
340+
num_layers, global_interval = 48, 4
341+
full_attention_layer_ids = list(
342+
range(global_interval - 1, num_layers, global_interval)
343+
)
344+
mamba_layers = [
345+
i for i in range(num_layers) if i not in full_attention_layer_ids
346+
]
347+
os.environ["SGLANG_MAMBA_SSM_DTYPE"] = "bfloat16"
348+
349+
mamba2_cache_params = Mamba2CacheParams(
350+
shape=Mamba2StateShape.create(
351+
tp_world_size=1,
352+
intermediate_size=4096,
353+
n_groups=16,
354+
num_heads=32,
355+
head_dim=128,
356+
state_size=128,
357+
conv_kernel=4,
358+
),
359+
layers=mamba_layers,
360+
)
361+
req_to_token_pool = HybridReqToTokenPool(
362+
size=10,
363+
mamba_size=20,
364+
mamba_spec_state_size=10,
365+
max_context_len=128,
366+
device="cuda",
367+
enable_memory_saver=False,
368+
cache_params=mamba2_cache_params,
369+
enable_mamba_extra_buffer=False,
370+
speculative_num_draft_tokens=3,
371+
)
372+
pool = HybridLinearKVPool(
373+
size=128,
374+
dtype=torch.bfloat16,
375+
page_size=1,
376+
head_num=2,
377+
head_dim=256,
378+
full_attention_layer_ids=full_attention_layer_ids,
379+
enable_kvcache_transpose=False,
380+
device="cuda",
381+
enable_memory_saver=False,
382+
mamba_pool=req_to_token_pool.mamba_pool,
383+
)
384+
allocator = TokenToKVPoolAllocator(
385+
size=128,
386+
dtype=torch.bfloat16,
387+
device="cuda",
388+
kvcache=pool,
389+
need_sort=False,
390+
)
391+
tree = MambaRadixCache(
392+
params=CacheInitParams(
393+
req_to_token_pool=req_to_token_pool,
394+
token_to_kv_pool_allocator=allocator,
395+
page_size=1,
396+
disable=False,
397+
)
398+
)
399+
mamba_pool = req_to_token_pool.mamba_pool
400+
401+
# Insert req1 to create cached mamba state
402+
sampling_params = SamplingParams(temperature=0, max_new_tokens=1)
403+
req1 = Req(
404+
rid=0,
405+
origin_input_text="",
406+
origin_input_ids=[],
407+
sampling_params=sampling_params,
408+
)
409+
req_to_token_pool.alloc(1, reqs=[req1])
410+
token_ids = [1, 2, 3, 4, 5]
411+
tree.insert(
412+
RadixKey(token_ids),
413+
allocator.alloc(len(token_ids)),
414+
req1.mamba_pool_idx.unsqueeze(0),
415+
)
416+
417+
initial_available = mamba_pool.available_size()
418+
419+
# req2 matches prefix with COW - this allocates a new mamba slot
420+
req2 = Req(
421+
rid=1,
422+
origin_input_text="",
423+
origin_input_ids=[],
424+
sampling_params=sampling_params,
425+
)
426+
tree.match_prefix(RadixKey(token_ids), req=req2, cow_mamba=True)
427+
428+
# Verify COW allocated a mamba slot
429+
assert req2.mamba_pool_idx is not None, "COW should allocate mamba slot"
430+
assert (
431+
mamba_pool.available_size() < initial_available
432+
), "Pool size should decrease"
433+
434+
# Simulate scheduling failure cleanup
435+
mamba_pool.free(req2.mamba_pool_idx.unsqueeze(-1))
436+
req2.mamba_pool_idx = None
437+
438+
# Verify slot is released
439+
assert (
440+
mamba_pool.available_size() == initial_available
441+
), "Slot should be released"
442+
339443

340444
if __name__ == "__main__":
341445
unittest.main()

0 commit comments

Comments
 (0)