Skip to content

Commit 26f2b37

Browse files
authored
[DLLM] Basic dLLM scheduling strategy and implementation (sgl-project#17484)
Signed-off-by: Zehuan Li <lizehuan.lzh@antgroup.com>
1 parent 8da14ae commit 26f2b37

File tree

9 files changed

+461
-210
lines changed

9 files changed

+461
-210
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from __future__ import annotations
2+
3+
import enum
4+
from typing import TYPE_CHECKING, Optional
5+
6+
from sglang.srt.dllm.config import DllmConfig
7+
8+
if TYPE_CHECKING:
9+
from sglang.srt.managers.schedule_batch import Req
10+
11+
12+
class DllmReqPhase(str, enum.Enum):
    """Lifecycle phase of a diffusion-LLM (dLLM) request.

    ``INCOMING_*`` phases mark requests that have not yet been granted
    resources; ``STAGING_*`` phases mark requests already staged by the
    scheduler. ``PREFILL`` vs ``DECODE`` is decided by whether the current
    block still contains the mask token (see
    ``ReqDllmMixin.determine_dllm_phase``). The ``str`` mixin makes members
    compare and serialize as their plain string values.
    """

    # Staged; next block holds no mask token, so it is pure prompt prefill.
    STAGING_PREFILL = "staging_prefill"
    # Staged; next block still contains mask tokens to be denoised.
    STAGING_DECODE = "staging_decode"
    # Not yet staged; prompt is at least one block long, so prefill first.
    INCOMING_PREFILL = "incoming_prefill"
    # Not yet staged; prompt shorter than one block, so decode immediately.
    INCOMING_DECODE = "incoming_decode"
17+
18+
19+
class ReqDllmMixin:
    """Request-side bookkeeping and phase transitions for diffusion-LLM runs.

    Mixed into ``Req``; all methods read/write request attributes such as
    ``origin_input_ids``, ``prefix_indices`` and ``fill_ids``.
    """

    def init_diffusion_llm(self: Req, dllm_config: DllmConfig):
        """Attach dLLM state to the request and choose its initial phase."""
        self.dllm_phase: Optional[DllmReqPhase] = None
        self.dllm_ids = []
        self.dllm_block_offset = 0
        self.dllm_config = dllm_config

        if self.dllm_config is None:
            # dLLM disabled: leave the phase unset.
            return

        # A prompt shorter than one block cannot fill a prefill block,
        # so such requests skip prefill and start decoding right away.
        too_short = len(self.origin_input_ids) < self.dllm_config.block_size
        self.dllm_phase = (
            DllmReqPhase.INCOMING_DECODE
            if too_short
            else DllmReqPhase.INCOMING_PREFILL
        )

    def is_dllm(self: Req) -> bool:
        """True when this request runs under a diffusion-LLM config."""
        return self.dllm_config is not None

    def is_dllm_prefill(self: Req) -> bool:
        """True when the request is in either prefill phase."""
        return self.dllm_phase in (
            DllmReqPhase.STAGING_PREFILL,
            DllmReqPhase.INCOMING_PREFILL,
        )

    def determine_dllm_phase(self: Req):
        """Re-classify a request as staging prefill or staging decode.

        Inspects the next block of ``fill_ids`` past the cached prefix:
        if it contains no mask token the block is pure prompt (prefill),
        otherwise it must be denoised (decode). Requests that do not yet
        have a full block of tokens keep their current (incoming) phase.
        """
        cached_len = len(self.prefix_indices)
        block_end = cached_len + self.dllm_config.block_size

        if len(self.fill_ids) < block_end:
            # Not enough tokens for a full block yet: still incoming stage.
            return

        next_block = self.fill_ids[cached_len:block_end]
        if self.dllm_config.mask_id in next_block:
            self.dllm_phase = DllmReqPhase.STAGING_DECODE
        else:
            self.dllm_phase = DllmReqPhase.STAGING_PREFILL

    def _init_fill_ids_for_dllm(self: Req):
        """Append one block of mask tokens to ``dllm_ids`` and sync ``fill_ids``."""
        mask_block = [self.dllm_config.mask_id] * self.dllm_config.block_size
        if self.dllm_ids:
            # Later rounds: advance the block offset and grow in place.
            self.dllm_block_offset += self.dllm_config.block_size
            self.dllm_ids.extend(mask_block)
        else:
            # First round: seed from the prompt (new list; prompt untouched).
            self.dllm_ids = self.origin_input_ids + mask_block

        self.fill_ids = self.dllm_ids
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import time
5+
from typing import TYPE_CHECKING, List, Optional, Set, Union
6+
7+
from sglang.srt.dllm.config import DllmConfig
8+
from sglang.srt.dllm.mixin.req import DllmReqPhase
9+
from sglang.srt.managers.schedule_batch import Req, RequestStage, ScheduleBatch
10+
from sglang.srt.managers.schedule_policy import AddReqResult, PrefillAdder
11+
from sglang.srt.model_executor.forward_batch_info import ForwardMode
12+
13+
logger = logging.getLogger(__name__)
14+
15+
if TYPE_CHECKING:
16+
from sglang.srt.managers.scheduler import Scheduler
17+
18+
19+
class SchedulerDllmMixin:
    """Scheduler-side batch construction for diffusion-LLM (dLLM) requests.

    Mixed into ``Scheduler`` (TYPE_CHECKING import). Builds DLLM_EXTEND
    batches from requests tracked by a ``DllmManager``, reusing the regular
    ``PrefillAdder`` for token accounting and preemption.
    """

    def init_diffusion_llm(self: Scheduler):
        """Create ``dllm_config`` (when enabled) and the ``DllmManager``."""
        # dLLM is active only when server_args.dllm_algorithm is set;
        # with dllm_config None the manager reports itself empty.
        self.dllm_config = (
            DllmConfig.from_server_args(self.server_args)
            if self.server_args.dllm_algorithm is not None
            else None
        )
        self.dllm_manager = DllmManager(dllm_config=self.dllm_config)

    def get_new_batch_dllm(self: Scheduler) -> Optional[ScheduleBatch]:
        """Generate a new batch for DLLM (Diffusion LLM) scheduling.

        Returns ``None`` when nothing is schedulable this round, otherwise
        a ``ScheduleBatch`` whose forward mode is ``DLLM_EXTEND``.
        """
        if self.try_preemption:
            # Preemption can free slots, so re-evaluate fullness each round.
            self.running_batch.batch_is_full = False

        # Early exit if batch is full or no requests available
        if self._should_skip_prefill():
            return None

        running_bs = len(self.running_batch.reqs)
        self.policy.calc_priority(self.waiting_queue)

        # Create prefill adder with resource constraints
        adder = self._create_dllm_prefill_adder(running_bs)

        # Initialize DLLM manager and transfer requests
        self.dllm_manager.init_next_round()
        self._fetch_waiting_reqs()

        # Process batches (prefill preferred, else decode)
        forward_mode = self._process_dllm_batches(adder)

        can_run_list = adder.can_run_list
        if not can_run_list:
            return None

        # Record metrics and update state
        self._update_metrics_and_state_for_batch(can_run_list, adder, running_bs)

        # Create and prepare batch
        new_batch = self._create_dllm_batch(can_run_list, forward_mode)
        return new_batch

    def _fetch_waiting_reqs(self: Scheduler):
        """Move waiting requests into the DllmManager, bounded by
        ``server_args.max_running_requests`` minus what it already holds."""
        # Calculate how many requests can be added to DLLM manager
        max_dllm_capacity = self.server_args.max_running_requests - len(
            self.dllm_manager.waiting_queue
        )
        num_requests_to_add = min(max_dllm_capacity, len(self.waiting_queue))

        if num_requests_to_add > 0:
            requests_to_add = self.waiting_queue[:num_requests_to_add]
            self.dllm_manager.add_waiting_reqs(requests_to_add)
            self.waiting_queue = self.waiting_queue[num_requests_to_add:]

    def _should_skip_prefill(self: Scheduler) -> bool:
        """Check if DLLM prefill should be skipped.

        Skips when there is no work (batch full / empty queues) or no
        allocatable request slot remains and preemption is off; the latter
        case also marks the running batch as full.
        """
        if (
            self.running_batch.batch_is_full or not self.waiting_queue
        ) and self.dllm_manager.is_empty():
            return True

        running_bs = len(self.running_batch.reqs)
        if (
            self.get_num_allocatable_reqs(running_bs) <= 0
            and self.dllm_manager.is_empty()
            and not self.try_preemption
        ):
            self.running_batch.batch_is_full = True
            return True

        return False

    def _create_dllm_prefill_adder(self: Scheduler, running_bs: int) -> PrefillAdder:
        """Create a prefill adder configured for DLLM scheduling."""
        return PrefillAdder(
            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            # Mixed-chunk mode accounts for the currently running requests.
            running_bs if self.is_mixed_chunk else 0,
            self.priority_scheduling_preemption_threshold,
            prefill_max_requests=self.server_args.prefill_max_requests,
            dllm_config=self.dllm_config,
        )

    def _process_dllm_batches(self: Scheduler, adder: PrefillAdder) -> ForwardMode:
        """Process prefill or decode batches for DLLM.

        Prefill requests take priority; decode requests are considered only
        when no prefill request is waiting. Always returns DLLM_EXTEND.
        """
        forward_mode = ForwardMode.DLLM_EXTEND

        # Try prefill batch first
        prefill_reqs = self.dllm_manager.get_prefill_requests()
        if prefill_reqs:
            self._process_batch_by_phase(
                adder,
                prefill_reqs,
                DllmReqPhase.STAGING_PREFILL,
                DllmReqPhase.INCOMING_PREFILL,
            )
        else:
            # Fall back to decode batch
            decode_reqs = self.dllm_manager.get_decode_requests()
            self._process_batch_by_phase(
                adder,
                decode_reqs,
                DllmReqPhase.STAGING_DECODE,
                DllmReqPhase.INCOMING_DECODE,
            )

        return forward_mode

    def _process_batch_by_phase(
        self: Scheduler,
        adder: PrefillAdder,
        batch: List[Req],
        staging_phase: DllmReqPhase,
        incoming_phase: DllmReqPhase,
    ) -> None:
        """Process a batch, separating staging and incoming requests.

        Staging requests (already holding resources) are admitted first;
        incoming requests are only attempted when all staging requests were
        accepted (result CONTINUE).
        """
        staging_reqs = [req for req in batch if req.dllm_phase == staging_phase]
        if staging_reqs:
            staging_result = self.process_dllm_staging_reqs(adder, staging_reqs)
            if staging_result != AddReqResult.CONTINUE:
                # Out of tokens for already-staged work: admit nothing new.
                return

        incoming_reqs = [req for req in batch if req.dllm_phase == incoming_phase]
        if incoming_reqs:
            self.process_dllm_incoming_reqs(adder, incoming_reqs)

    def _update_metrics_and_state_for_batch(
        self: Scheduler, can_run_list: List[Req], adder: PrefillAdder, running_bs: int
    ) -> None:
        """Update metrics and state for the batch.

        NOTE(review): the ``running_bs`` parameter is currently unused here;
        ``self.running_bs`` is recomputed from the running batch below.
        """
        if self.enable_metrics:
            for req in can_run_list:
                req.add_latency(RequestStage.PREFILL_WAITING)

        # Requests preempted by the adder go back to the waiting queue.
        if adder.preempt_list:
            for req in adder.preempt_list:
                self._add_request_to_queue(req)

        if can_run_list:
            self.dllm_manager.add_staging_reqs(can_run_list)
            self.dllm_manager.increment_chunked_count()

        # Mirror adder/batch state onto the scheduler for later rounds.
        self.adder = adder
        self.can_run_list = can_run_list
        self.running_bs = len(self.running_batch.reqs)

        for req in can_run_list:
            # First time entering forward: stamp entry time and queueing metric.
            if req.time_stats.forward_entry_time == 0:
                req.time_stats.forward_entry_time = time.perf_counter()
                if self.enable_metrics:
                    self.metrics_collector.observe_queue_time(
                        req.time_stats.get_queueing_time(),
                    )

    def _create_dllm_batch(
        self: Scheduler, can_run_list: List[Req], forward_mode: ForwardMode
    ) -> ScheduleBatch:
        """Create and prepare a new DLLM batch.

        ``forward_mode`` is assigned after ``prepare_for_extend`` so it
        overrides whatever mode preparation set.
        """
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            dllm_config=self.dllm_config,
        )
        new_batch.prepare_for_extend()
        new_batch.forward_mode = forward_mode
        new_batch.decoding_reqs = None
        return new_batch

    def process_dllm_incoming_reqs(
        self: Scheduler, adder: PrefillAdder, reqs: List[Req]
    ) -> AddReqResult:
        """Process incoming DLLM requests with resource allocation and preemption.

        Admits requests one by one; stops at the first request that cannot
        be scheduled (full batch without successful preemption, or a
        non-CONTINUE adder result). Returns the last adder result.
        """
        res = AddReqResult.CONTINUE
        for req in reqs:
            # Check if batch is full
            running_bs = len(self.running_batch.reqs)
            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True

            # Try preemption if batch is full
            if self.running_batch.batch_is_full:
                if not self.try_preemption or not adder.preempt_to_schedule(
                    req, self.server_args
                ):
                    break

            # Prepare and add request
            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(
                req,
                has_chunked_req=True,
                truncation_align_size=self.truncation_align_size,
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    # KV tokens exhausted: no further request can fit.
                    self.running_batch.batch_is_full = True
                break

        return res

    def process_dllm_staging_reqs(
        self: Scheduler, adder: PrefillAdder, reqs: List[Req]
    ) -> AddReqResult:
        """Process staging DLLM requests with resource allocation.

        Stops early and reports NO_TOKEN as soon as the adder runs out of
        tokens; otherwise returns CONTINUE.
        """
        for req in reqs:
            res = adder.add_dllm_staging_req(req)
            if res == AddReqResult.NO_TOKEN:
                return res

        return AddReqResult.CONTINUE
240+
241+
242+
class DllmManager:
    """Tracks diffusion-LLM requests across scheduling rounds.

    Two queues are maintained:
    - ``waiting_queue``: requests admitted for dLLM scheduling, capped by
      the config's max running requests.
    - ``staging_queue``: requests that were granted resources by the
      ``PrefillAdder`` in the current round.
    """

    def __init__(self, dllm_config: Optional[DllmConfig] = None):
        self.dllm_config = dllm_config
        # Without a config the manager is inert; cap defaults to 1.
        self.max_running_reqs = (
            1 if dllm_config is None else dllm_config.max_running_requests
        )
        self.waiting_queue: List[Req] = []
        self.staging_queue: List[Req] = []

    def get_prefill_requests(self) -> List[Req]:
        """Waiting requests currently in a prefill phase."""
        return [r for r in self.waiting_queue if r.is_dllm_prefill()]

    def get_decode_requests(self) -> List[Req]:
        """Waiting requests currently in a decode phase."""
        return [r for r in self.waiting_queue if not r.is_dllm_prefill()]

    def add_waiting_reqs(self, reqs: Union[Req, List[Req]]) -> None:
        """Append request(s) to the waiting queue, rejecting duplicate rids."""
        assert self.dllm_config is not None, "Diffusion LLM config is not set."

        incoming = [reqs] if not isinstance(reqs, list) else reqs

        # Refuse to enqueue a request whose rid is already tracked.
        if self._has_duplicate_reqs(incoming):
            raise RuntimeError("Redundant requests detected in dLLM requests.")

        self.waiting_queue.extend(incoming)

    def add_staging_reqs(self, reqs: Union[Req, List[Req]]) -> None:
        """Append request(s) granted resources by the PrefillAdder."""
        incoming = [reqs] if not isinstance(reqs, list) else reqs
        self.staging_queue.extend(incoming)

    def _has_duplicate_reqs(self, reqs: List[Req]) -> bool:
        """True when any incoming rid is already in the waiting queue."""
        known: Set[str] = {r.rid for r in self.waiting_queue}
        return not known.isdisjoint(req.rid for req in reqs)

    def any_staging_reqs(self) -> bool:
        """True when dLLM is configured and the staging queue is non-empty."""
        return bool(self.staging_queue) and self.dllm_config is not None

    def is_empty(self) -> bool:
        """True when dLLM is disabled or no requests are waiting.

        Note: only the waiting queue is consulted; staged requests do not
        count.
        """
        return self.dllm_config is None or not self.waiting_queue

    def increment_chunked_count(self) -> None:
        """Bump ``is_chunked`` on every staged request."""
        for staged in self.staging_queue:
            staged.is_chunked += 1

    def filter_finished_reqs(self) -> None:
        """Drop finished requests from both queues."""
        self.waiting_queue = [r for r in self.waiting_queue if not r.finished()]
        self.staging_queue = [r for r in self.staging_queue if not r.finished()]

    def init_next_round(self) -> None:
        """Re-initialize every staged request's input and clear the staging queue."""
        staged, self.staging_queue = self.staging_queue, []
        for req in staged:
            req.init_next_round_input()

0 commit comments

Comments
 (0)