11from __future__ import annotations
22
3+ import dataclasses
34import logging
45import time
56from collections import defaultdict
67from contextlib import contextmanager
7- from typing import TYPE_CHECKING , Dict , List , Optional , Union
8+ from typing import TYPE_CHECKING , Dict , Optional , Union
89
910from sglang .srt .disaggregation .kv_events import EventPublisherFactory , KVEventBatch
1011from sglang .srt .disaggregation .utils import DisaggregationMode
2021 QueueMetrics ,
2122 SpeculativeMetrics ,
2223)
23- from sglang .srt .managers .schedule_policy import PrefillAdder
24- from sglang .srt .managers .scheduler import Req , ScheduleBatch
24+ from sglang .srt .managers .scheduler import ScheduleBatch
2525from sglang .srt .managers .utils import GenerationBatchResult
2626from sglang .srt .metrics .collector import (
2727 SchedulerMetricsCollector ,
4242ENABLE_METRICS_DEVICE_TIMER = envs .SGLANG_ENABLE_METRICS_DEVICE_TIMER .get ()
4343
4444
@dataclasses.dataclass
class PrefillStats:
    """Stats for logging prefill batch metrics.

    A plain value bundle passed to ``log_prefill_stats`` so the logger
    receives the numbers it reports as a single argument.
    """

    # Token count reported as "#new-token" in the prefill log line.
    log_input_tokens: int
    # Token count reported as "#cached-token"; together with
    # log_input_tokens it forms the denominator of the cache hit rate.
    log_hit_tokens: int
    # Copied onto the scheduler's ``stats.new_token_ratio``.
    new_token_ratio: float
    # Reported as "#running-req" and stored as ``stats.num_running_reqs``.
    running_bs: int
    # Reported as "#new-seq"; len(can_run_list)
    num_new_seqs: int
54+
55+
4556class KvMetrics :
4657 def __init__ (self ):
4758 self .request_active_slots = None
@@ -148,21 +159,18 @@ def reset_metrics(self: Scheduler):
148159
149160 def log_prefill_stats (
150161 self : Scheduler ,
151- adder : PrefillAdder ,
152- can_run_list : List [Req ],
153- running_bs : int ,
154- running_bs_offline_batch : int ,
162+ prefill_stats : PrefillStats ,
155163 can_run_cuda_graph : bool ,
156164 ):
157165 gap_latency = time .perf_counter () - self .last_prefill_stats_tic
158166 self .last_prefill_stats_tic = time .perf_counter ()
159167 self .last_input_throughput = self .last_prefill_tokens / gap_latency
160- self .last_prefill_tokens = adder .log_input_tokens
168+ self .last_prefill_tokens = prefill_stats .log_input_tokens
161169
162170 assert self .temp_prefill_info is None
163171 self .temp_prefill_info = dict (
164- adder_log_input_tokens = adder .log_input_tokens ,
165- adder_log_hit_tokens = adder .log_hit_tokens ,
172+ adder_log_input_tokens = prefill_stats .log_input_tokens ,
173+ adder_log_hit_tokens = prefill_stats .log_hit_tokens ,
166174 )
167175
168176 # TODO: generalize this for various memory pools
@@ -204,16 +212,16 @@ def log_prefill_stats(
204212 num_used , token_usage , _ , _ = self ._get_token_info ()
205213 token_usage_msg = f"token usage: { token_usage :.2f} , "
206214
207- self .stats .new_token_ratio = adder .new_token_ratio
215+ self .stats .new_token_ratio = prefill_stats .new_token_ratio
208216 iter_msg = f" [{ self .forward_ct + 1 } ]" if LOG_FORWARD_ITERS else ""
209217
210218 msg = (
211219 f"Prefill batch{ iter_msg } , "
212- f"#new-seq: { len ( can_run_list ) } , "
213- f"#new-token: { adder .log_input_tokens } , "
214- f"#cached-token: { adder .log_hit_tokens } , "
220+ f"#new-seq: { prefill_stats . num_new_seqs } , "
221+ f"#new-token: { prefill_stats .log_input_tokens } , "
222+ f"#cached-token: { prefill_stats .log_hit_tokens } , "
215223 f"{ token_usage_msg } "
216- f"#running-req: { running_bs } , "
224+ f"#running-req: { prefill_stats . running_bs } , "
217225 f"#queue-req: { len (self .waiting_queue )} , "
218226 )
219227
@@ -240,13 +248,13 @@ def log_prefill_stats(
240248
241249 if self .enable_metrics :
242250 # Basics
243- total_tokens = adder .log_input_tokens + adder .log_hit_tokens
251+ total_tokens = prefill_stats .log_input_tokens + prefill_stats .log_hit_tokens
244252 cache_hit_rate = (
245- adder .log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
253+ prefill_stats .log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
246254 )
247255
248- self .stats .num_running_reqs = running_bs
249- self .stats .num_running_reqs_offline_batch = running_bs_offline_batch
256+ self .stats .num_running_reqs = prefill_stats . running_bs
257+ self .stats .num_running_reqs_offline_batch = 0
250258 self .stats .num_used_tokens = num_used
251259 self .stats .token_usage = token_usage
252260 if self .is_hybrid_swa :
0 commit comments