@@ -336,6 +336,110 @@ def make_dummy_req():
336336 == mamba_pool .mamba_cache .temporal [:, last_node .mamba_value ]
337337 )
338338
339+ def test_mamba_slot_release_after_match_prefix_cow (self ):
340+ num_layers , global_interval = 48 , 4
341+ full_attention_layer_ids = list (
342+ range (global_interval - 1 , num_layers , global_interval )
343+ )
344+ mamba_layers = [
345+ i for i in range (num_layers ) if i not in full_attention_layer_ids
346+ ]
347+ os .environ ["SGLANG_MAMBA_SSM_DTYPE" ] = "bfloat16"
348+
349+ mamba2_cache_params = Mamba2CacheParams (
350+ shape = Mamba2StateShape .create (
351+ tp_world_size = 1 ,
352+ intermediate_size = 4096 ,
353+ n_groups = 16 ,
354+ num_heads = 32 ,
355+ head_dim = 128 ,
356+ state_size = 128 ,
357+ conv_kernel = 4 ,
358+ ),
359+ layers = mamba_layers ,
360+ )
361+ req_to_token_pool = HybridReqToTokenPool (
362+ size = 10 ,
363+ mamba_size = 20 ,
364+ mamba_spec_state_size = 10 ,
365+ max_context_len = 128 ,
366+ device = "cuda" ,
367+ enable_memory_saver = False ,
368+ cache_params = mamba2_cache_params ,
369+ enable_mamba_extra_buffer = False ,
370+ speculative_num_draft_tokens = 3 ,
371+ )
372+ pool = HybridLinearKVPool (
373+ size = 128 ,
374+ dtype = torch .bfloat16 ,
375+ page_size = 1 ,
376+ head_num = 2 ,
377+ head_dim = 256 ,
378+ full_attention_layer_ids = full_attention_layer_ids ,
379+ enable_kvcache_transpose = False ,
380+ device = "cuda" ,
381+ enable_memory_saver = False ,
382+ mamba_pool = req_to_token_pool .mamba_pool ,
383+ )
384+ allocator = TokenToKVPoolAllocator (
385+ size = 128 ,
386+ dtype = torch .bfloat16 ,
387+ device = "cuda" ,
388+ kvcache = pool ,
389+ need_sort = False ,
390+ )
391+ tree = MambaRadixCache (
392+ params = CacheInitParams (
393+ req_to_token_pool = req_to_token_pool ,
394+ token_to_kv_pool_allocator = allocator ,
395+ page_size = 1 ,
396+ disable = False ,
397+ )
398+ )
399+ mamba_pool = req_to_token_pool .mamba_pool
400+
401+ # Insert req1 to create cached mamba state
402+ sampling_params = SamplingParams (temperature = 0 , max_new_tokens = 1 )
403+ req1 = Req (
404+ rid = 0 ,
405+ origin_input_text = "" ,
406+ origin_input_ids = [],
407+ sampling_params = sampling_params ,
408+ )
409+ req_to_token_pool .alloc (1 , reqs = [req1 ])
410+ token_ids = [1 , 2 , 3 , 4 , 5 ]
411+ tree .insert (
412+ RadixKey (token_ids ),
413+ allocator .alloc (len (token_ids )),
414+ req1 .mamba_pool_idx .unsqueeze (0 ),
415+ )
416+
417+ initial_available = mamba_pool .available_size ()
418+
419+ # req2 matches prefix with COW - this allocates a new mamba slot
420+ req2 = Req (
421+ rid = 1 ,
422+ origin_input_text = "" ,
423+ origin_input_ids = [],
424+ sampling_params = sampling_params ,
425+ )
426+ tree .match_prefix (RadixKey (token_ids ), req = req2 , cow_mamba = True )
427+
428+ # Verify COW allocated a mamba slot
429+ assert req2 .mamba_pool_idx is not None , "COW should allocate mamba slot"
430+ assert (
431+ mamba_pool .available_size () < initial_available
432+ ), "Pool size should decrease"
433+
434+ # Simulate scheduling failure cleanup
435+ mamba_pool .free (req2 .mamba_pool_idx .unsqueeze (- 1 ))
436+ req2 .mamba_pool_idx = None
437+
438+ # Verify slot is released
439+ assert (
440+ mamba_pool .available_size () == initial_available
441+ ), "Slot should be released"
442+
339443
# Allow running this test module directly (e.g. `python test_file.py`).
if __name__ == "__main__":
    unittest.main()
0 commit comments