@@ -133,7 +133,6 @@ def __init__(
133133 device : str ,
134134 enable_memory_saver : bool ,
135135 ):
136-
137136 memory_saver_adapter = TorchMemorySaverAdapter .create (
138137 enable = enable_memory_saver
139138 )
@@ -145,7 +144,6 @@ def __init__(
145144 self .req_to_token = torch .zeros (
146145 (size , max_context_len ), dtype = torch .int32 , device = device
147146 )
148-
149147 self .free_slots = list (range (size ))
150148
151149 def write (self , indices , values ):
@@ -154,20 +152,32 @@ def write(self, indices, values):
154152 def available_size (self ):
155153 return len (self .free_slots )
156154
157- def alloc (self , need_size : int ) -> List [int ]:
155+ def alloc (self , reqs : list [Req ]) -> Optional [List [int ]]:
156+ chunked = [i for i , r in enumerate (reqs ) if r .req_pool_idx is not None ]
157+ if not any (r .is_dllm () for r in reqs ):
158+ assert (
159+ len (chunked ) <= 1
160+ ), "only one chunked request may reuse req_pool_idx in a batch"
161+ assert all (
162+ reqs [i ].is_chunked > 0 or reqs [i ].kv_committed_len > 0 for i in chunked
163+ ), "request has req_pool_idx but is not chunked"
164+
165+ need_size = len (reqs ) - len (chunked )
158166 if need_size > len (self .free_slots ):
159167 return None
160-
161168 select_index = self .free_slots [:need_size ]
162169 self .free_slots = self .free_slots [need_size :]
163-
164- return select_index
165-
166- def free (self , free_index : Union [int , List [int ]]):
167- if isinstance (free_index , (int ,)):
168- self .free_slots .append (free_index )
169- else :
170- self .free_slots .extend (free_index )
170+ offset = 0
171+ for r in reqs :
172+ if r .req_pool_idx is None :
173+ r .req_pool_idx = select_index [offset ]
174+ offset += 1
175+ return [r .req_pool_idx for r in reqs ]
176+
177+ def free (self , req : Req ):
178+ assert req .req_pool_idx is not None , "request must have req_pool_idx"
179+ self .free_slots .append (req .req_pool_idx )
180+ req .req_pool_idx = None
171181
172182 def clear (self ):
173183 self .free_slots = list (range (self .size ))
@@ -488,10 +498,9 @@ def _init_mamba_pool(
488498
489499 # For a chunked-prefill request we do not need to allocate a new mamba cache;
490500 # we can reuse the mamba cache that was already allocated for it.
491- def alloc (self , need_size : int , reqs : Optional [List ["Req" ]]) -> Optional [List [int ]]:
492- assert reqs is not None
493- select_index = super ().alloc (need_size )
494- if select_index == None :
501+ def alloc (self , reqs : List ["Req" ]) -> Optional [List [int ]]:
502+ select_index = super ().alloc (reqs )
503+ if select_index is None :
495504 return None
496505
497506 mamba_index = []
@@ -556,37 +565,29 @@ def get_mamba_ping_pong_other_idx(self, mamba_next_track_idx: int) -> int:
556565 else :
557566 return mamba_next_track_idx
558567
559- # For chunk prefill, we can not free mamba cache, we need use it in the future
560- def free (
561- self ,
562- free_index : Union [int , List [int ]],
563- free_mamba_cache : bool = True ,
564- mamba_ping_pong_track_buffer_to_keep : Optional [int ] = None ,
568+ def free_mamba_cache (
569+ self , req : "Req" , mamba_ping_pong_track_buffer_to_keep : Optional [int ] = None
565570 ):
566- if isinstance (free_index , (int ,)):
567- free_index = [free_index ]
568- super ().free (free_index )
569- if free_mamba_cache :
570- mamba_index = self .req_index_to_mamba_index_mapping [free_index ]
571- self .mamba_pool .free (mamba_index )
571+ mamba_index = req .mamba_pool_idx
572+ assert mamba_index is not None , "double free? mamba_index is None"
573+ self .mamba_pool .free (mamba_index .unsqueeze (0 ))
574+ req .mamba_pool_idx = None
572575
573- if self .enable_mamba_extra_buffer :
576+ if self .enable_mamba_extra_buffer :
577+ mamba_ping_pong_track_buffer_to_free = (
578+ self .req_index_to_mamba_ping_pong_track_buffer_mapping [req .req_pool_idx ]
579+ )
580+ if mamba_ping_pong_track_buffer_to_keep is not None :
581+ assert mamba_ping_pong_track_buffer_to_keep in [
582+ 0 ,
583+ 1 ,
584+ ], f"mamba_ping_pong_track_buffer_to_keep must be 0 or 1, { mamba_ping_pong_track_buffer_to_keep = } "
585+ idx_to_free = list (range (self .mamba_ping_pong_track_buffer_size ))
586+ idx_to_free .remove (mamba_ping_pong_track_buffer_to_keep )
574587 mamba_ping_pong_track_buffer_to_free = (
575- self .req_index_to_mamba_ping_pong_track_buffer_mapping [
576- free_index
577- ].squeeze (0 )
588+ mamba_ping_pong_track_buffer_to_free [idx_to_free ]
578589 )
579- if mamba_ping_pong_track_buffer_to_keep is not None :
580- assert mamba_ping_pong_track_buffer_to_keep in [
581- 0 ,
582- 1 ,
583- ], f"mamba_ping_pong_track_buffer_to_keep must be 0 or 1, { mamba_ping_pong_track_buffer_to_keep = } "
584- idx_to_free = list (range (self .mamba_ping_pong_track_buffer_size ))
585- idx_to_free .remove (mamba_ping_pong_track_buffer_to_keep )
586- mamba_ping_pong_track_buffer_to_free = (
587- mamba_ping_pong_track_buffer_to_free [idx_to_free ]
588- )
589- self .mamba_pool .free (mamba_ping_pong_track_buffer_to_free )
590+ self .mamba_pool .free (mamba_ping_pong_track_buffer_to_free )
590591
591592 def clear (self ):
592593 logger .info ("Reset HybridReqToTokenPool" )
0 commit comments