Skip to content

Commit 93fca0b

Browse files
authored
Fix wrong prefill log. (sgl-project#18570)
1 parent 2bfab1b commit 93fca0b

File tree

4 files changed

+42
-27
lines changed

4 files changed

+42
-27
lines changed

python/sglang/srt/managers/schedule_batch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
from typing import Any, Dict
9393

9494
from sglang.srt.configs.model_config import ModelConfig
95+
from sglang.srt.managers.scheduler_metrics_mixin import PrefillStats
9596
from sglang.srt.speculative.eagle_info import EagleDraftInput
9697
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
9798

@@ -1304,6 +1305,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
13041305

13051306
# Metrics
13061307
dp_cooperation_info: Optional[DPCooperationInfo] = None
1308+
prefill_stats: Optional[PrefillStats] = None
13071309

13081310
@classmethod
13091311
def init_new(
@@ -2243,6 +2245,7 @@ def copy(self):
22432245
mamba_track_mask=self.mamba_track_mask,
22442246
mamba_track_seqlens=self.mamba_track_seqlens,
22452247
dp_cooperation_info=self.dp_cooperation_info,
2248+
prefill_stats=self.prefill_stats,
22462249
)
22472250

22482251
def maybe_evict_swa(self):

python/sglang/srt/managers/scheduler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@
151151
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
152152
from sglang.srt.managers.scheduler_metrics_mixin import (
153153
RECORD_STEP_TIME,
154+
PrefillStats,
154155
SchedulerMetricsMixin,
155156
)
156157
from sglang.srt.managers.scheduler_output_processor_mixin import (
@@ -2121,6 +2122,15 @@ def _get_new_batch_prefill_raw(
21212122

21222123
new_batch.prepare_for_extend()
21232124

2125+
# Record prefill stats for logging after forward
2126+
new_batch.prefill_stats = PrefillStats(
2127+
log_input_tokens=adder.log_input_tokens,
2128+
log_hit_tokens=adder.log_hit_tokens,
2129+
new_token_ratio=adder.new_token_ratio,
2130+
running_bs=len(self.running_batch.reqs),
2131+
num_new_seqs=len(can_run_list),
2132+
)
2133+
21242134
# Mixed-style chunked prefill
21252135
if (
21262136
self.is_mixed_chunk

python/sglang/srt/managers/scheduler_metrics_mixin.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

3+
import dataclasses
34
import logging
45
import time
56
from collections import defaultdict
67
from contextlib import contextmanager
7-
from typing import TYPE_CHECKING, Dict, List, Optional, Union
8+
from typing import TYPE_CHECKING, Dict, Optional, Union
89

910
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
1011
from sglang.srt.disaggregation.utils import DisaggregationMode
@@ -20,8 +21,7 @@
2021
QueueMetrics,
2122
SpeculativeMetrics,
2223
)
23-
from sglang.srt.managers.schedule_policy import PrefillAdder
24-
from sglang.srt.managers.scheduler import Req, ScheduleBatch
24+
from sglang.srt.managers.scheduler import ScheduleBatch
2525
from sglang.srt.managers.utils import GenerationBatchResult
2626
from sglang.srt.metrics.collector import (
2727
SchedulerMetricsCollector,
@@ -42,6 +42,17 @@
4242
ENABLE_METRICS_DEVICE_TIMER = envs.SGLANG_ENABLE_METRICS_DEVICE_TIMER.get()
4343

4444

45+
@dataclasses.dataclass
46+
class PrefillStats:
47+
"""Stats for logging prefill batch metrics."""
48+
49+
log_input_tokens: int
50+
log_hit_tokens: int
51+
new_token_ratio: float
52+
running_bs: int
53+
num_new_seqs: int # len(can_run_list)
54+
55+
4556
class KvMetrics:
4657
def __init__(self):
4758
self.request_active_slots = None
@@ -148,21 +159,18 @@ def reset_metrics(self: Scheduler):
148159

149160
def log_prefill_stats(
150161
self: Scheduler,
151-
adder: PrefillAdder,
152-
can_run_list: List[Req],
153-
running_bs: int,
154-
running_bs_offline_batch: int,
162+
prefill_stats: PrefillStats,
155163
can_run_cuda_graph: bool,
156164
):
157165
gap_latency = time.perf_counter() - self.last_prefill_stats_tic
158166
self.last_prefill_stats_tic = time.perf_counter()
159167
self.last_input_throughput = self.last_prefill_tokens / gap_latency
160-
self.last_prefill_tokens = adder.log_input_tokens
168+
self.last_prefill_tokens = prefill_stats.log_input_tokens
161169

162170
assert self.temp_prefill_info is None
163171
self.temp_prefill_info = dict(
164-
adder_log_input_tokens=adder.log_input_tokens,
165-
adder_log_hit_tokens=adder.log_hit_tokens,
172+
adder_log_input_tokens=prefill_stats.log_input_tokens,
173+
adder_log_hit_tokens=prefill_stats.log_hit_tokens,
166174
)
167175

168176
# TODO: generalize this for various memory pools
@@ -204,16 +212,16 @@ def log_prefill_stats(
204212
num_used, token_usage, _, _ = self._get_token_info()
205213
token_usage_msg = f"token usage: {token_usage:.2f}, "
206214

207-
self.stats.new_token_ratio = adder.new_token_ratio
215+
self.stats.new_token_ratio = prefill_stats.new_token_ratio
208216
iter_msg = f" [{self.forward_ct + 1}]" if LOG_FORWARD_ITERS else ""
209217

210218
msg = (
211219
f"Prefill batch{iter_msg}, "
212-
f"#new-seq: {len(can_run_list)}, "
213-
f"#new-token: {adder.log_input_tokens}, "
214-
f"#cached-token: {adder.log_hit_tokens}, "
220+
f"#new-seq: {prefill_stats.num_new_seqs}, "
221+
f"#new-token: {prefill_stats.log_input_tokens}, "
222+
f"#cached-token: {prefill_stats.log_hit_tokens}, "
215223
f"{token_usage_msg}"
216-
f"#running-req: {running_bs}, "
224+
f"#running-req: {prefill_stats.running_bs}, "
217225
f"#queue-req: {len(self.waiting_queue)}, "
218226
)
219227

@@ -240,13 +248,13 @@ def log_prefill_stats(
240248

241249
if self.enable_metrics:
242250
# Basics
243-
total_tokens = adder.log_input_tokens + adder.log_hit_tokens
251+
total_tokens = prefill_stats.log_input_tokens + prefill_stats.log_hit_tokens
244252
cache_hit_rate = (
245-
adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
253+
prefill_stats.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
246254
)
247255

248-
self.stats.num_running_reqs = running_bs
249-
self.stats.num_running_reqs_offline_batch = running_bs_offline_batch
256+
self.stats.num_running_reqs = prefill_stats.running_bs
257+
self.stats.num_running_reqs_offline_batch = 0
250258
self.stats.num_used_tokens = num_used
251259
self.stats.token_usage = token_usage
252260
if self.is_hybrid_swa:

python/sglang/srt/managers/scheduler_output_processor_mixin.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,7 @@ def process_batch_result_prefill(
343343
if self.current_scheduler_metrics_enabled:
344344
can_run_cuda_graph = getattr(result, "can_run_cuda_graph", False)
345345
self.log_prefill_stats(
346-
adder=self.adder,
347-
can_run_list=self.can_run_list,
348-
running_bs=self.running_bs,
349-
running_bs_offline_batch=0,
346+
prefill_stats=batch.prefill_stats,
350347
can_run_cuda_graph=can_run_cuda_graph,
351348
)
352349

@@ -422,10 +419,7 @@ def process_batch_result_dllm(
422419
if self.current_scheduler_metrics_enabled:
423420
can_run_cuda_graph = getattr(result, "can_run_cuda_graph", False)
424421
self.log_prefill_stats(
425-
adder=self.adder,
426-
can_run_list=self.can_run_list,
427-
running_bs=self.running_bs,
428-
running_bs_offline_batch=0,
422+
prefill_stats=batch.prefill_stats,
429423
can_run_cuda_graph=can_run_cuda_graph,
430424
)
431425

0 commit comments

Comments (0)