Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
self fixing comments
  • Loading branch information
zhaochenyang20 committed Mar 3, 2026
commit 8effef7449d7c17309b1eeac8582a1b5df48693c
114 changes: 41 additions & 73 deletions python/sglang/multimodal_gen/benchmarks/bench_offline_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import torch
from tqdm import tqdm

from sglang.multimodal_gen.benchmarks.datasets import RandomDataset, VBenchDataset
Expand All @@ -52,13 +53,13 @@ class BatchOutput:
latency: float = 0.0
latency_per_sample: float = 0.0
num_samples: int = 0
total_frames_or_pixels: int = 0
total_frames: int = 0
peak_memory_mb: float = 0.0
success: bool = False
error: str = ""


@dataclasses.dataclass
@dataclass
class BenchArgs:
"""Benchmark configuration for multimodal generation."""

Expand All @@ -83,8 +84,6 @@ class BenchArgs:

# Benchmark Execution
skip_warmup: bool = False
profile: bool = False
num_runs: int = 1
output_file: str = ""
disable_tqdm: bool = False

Expand Down Expand Up @@ -146,21 +145,16 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Total number of prompts to benchmark",
)
parser.add_argument(
"--batch-size", type=int, default=1, help="Batch size per generation call"
"--batch-size",
type=int,
default=1,
help="Batch size per generation call (currently only bs=1 is supported)",
)

# Benchmark Execution
parser.add_argument(
"--skip-warmup", action="store_true", help="Skip warmup batch"
)
parser.add_argument(
"--profile",
action="store_true",
help="Enable torch profiler (use env var SGLANG_TORCH_PROFILER_DIR)",
)
parser.add_argument(
"--num-runs", type=int, default=1, help="Number of benchmark runs"
)
parser.add_argument(
"--output-file",
type=str,
Expand All @@ -180,19 +174,16 @@ def from_cli_args(cls, args: argparse.Namespace):
return cls(**{attr: getattr(args, attr) for attr in attrs})


def initialize_engine(server_args: ServerArgs) -> Any:
def initialize_engine(server_args: ServerArgs) -> DiffGenerator:
"""Initialize diffusion pipeline engine."""
logger.info("Initializing engine...")

# Initialize DiffGenerator in local mode
engine = DiffGenerator.from_server_args(server_args, local_mode=True)

logger.info(f"Engine initialized successfully")
logger.info("Engine initialized successfully")
return engine


def generate_batch(
engine: Any,
engine: DiffGenerator,
bench_args: BenchArgs,
prompts: List[str],
user_sampling_params: Dict[str, Any],
Expand All @@ -201,52 +192,48 @@ def generate_batch(
output = BatchOutput()
start_time = time.perf_counter()

try:
# Generate using DiffGenerator
for i, prompt in enumerate(prompts):
torch.cuda.reset_peak_memory_stats()

for prompt in prompts:
try:
sampling_params_kwargs = dict(user_sampling_params)
sampling_params_kwargs["prompt"] = prompt
result = engine.generate(sampling_params_kwargs=sampling_params_kwargs)

# Extract metrics from result
if result is not None:
if isinstance(result, list):
output.total_frames_or_pixels += len(result)
output.total_frames += len(result)
else:
output.total_frames_or_pixels += 1

output.latency = time.perf_counter() - start_time
output.latency_per_sample = output.latency / len(prompts)
output.num_samples = len(prompts)
output.success = True

logger.debug(
f"Batch generated: {len(prompts)} samples in {output.latency:.2f}s"
)

except Exception as e:
output.latency = time.perf_counter() - start_time
output.success = False
output.error = str(e)
logger.error(f"Batch generation failed: {e}")
output.total_frames += 1
output.num_samples += 1
except Exception as e:
logger.error(f"Generation failed for prompt '{prompt[:50]}...': {e}")
output.error = str(e)

output.latency = time.perf_counter() - start_time
output.latency_per_sample = output.latency / len(prompts) if prompts else 0.0
output.success = output.num_samples > 0
output.peak_memory_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

logger.debug(
f"Batch generated: {output.num_samples}/{len(prompts)} samples in {output.latency:.2f}s"
)

return output


def calculate_metrics(
outputs: List[BatchOutput],
total_duration: float,
task_type: str,
resolution: Tuple[int, int, int],
num_requests: int,
) -> Dict[str, Any]:
"""Calculate generation-specific throughput metrics."""
successful = [o for o in outputs if o.success]
num_success = len(successful)
total_frames = sum(o.total_frames_or_pixels for o in successful)
num_success = sum(o.num_samples for o in successful)
total_frames = sum(o.total_frames for o in successful)
peak_memory = max((o.peak_memory_mb for o in outputs), default=0)

# Resolution-based calculations
width, height, frames = resolution
pixels_per_sample = width * height * frames
total_pixels = num_success * pixels_per_sample
Expand Down Expand Up @@ -283,10 +270,8 @@ def throughput_test(
configure_logger(server_args=server_args)
logger.info("Starting offline throughput benchmark...")

# Initialize engine
engine = initialize_engine(server_args)

# Load dataset
logger.info(f"Loading {bench_args.dataset} dataset...")
if bench_args.dataset == "vbench":
bench_args.task_name = engine.server_args.pipeline_config.task_type
Expand All @@ -296,7 +281,6 @@ def throughput_test(
else:
raise ValueError(f"Unknown dataset: {bench_args.dataset}")

# Prepare sampling parameters
sampling_params = {
"guidance_scale": bench_args.guidance_scale,
"num_inference_steps": bench_args.num_inference_steps,
Expand All @@ -308,39 +292,30 @@ def throughput_test(
if bench_args.disable_safety_checker:
sampling_params["safety_checker"] = None

# Warmup (optional)
if not bench_args.skip_warmup:
logger.info("Running warmup batch...")
warmup_indices = list(range(min(bench_args.batch_size, len(dataset))))
warmup_batch = dataset.get_batch(warmup_indices)
warmup_prompts = [item["prompt"] for item in warmup_batch]
warmup_count = min(bench_args.batch_size, len(dataset))
warmup_prompts = [dataset[i].prompt for i in range(warmup_count)]
generate_batch(engine, bench_args, warmup_prompts, sampling_params)

# Main benchmark loop
logger.info(f"Running benchmark with {bench_args.num_prompts} prompts...")
outputs: List[BatchOutput] = []
total_indices = list(range(min(bench_args.num_prompts, len(dataset))))
total_count = min(bench_args.num_prompts, len(dataset))
all_prompts = [dataset[i].prompt for i in range(total_count)]

start_time = time.perf_counter()

# Process in batches
num_batches = (
len(total_indices) + bench_args.batch_size - 1
) // bench_args.batch_size
num_batches = (total_count + bench_args.batch_size - 1) // bench_args.batch_size
pbar = tqdm(
total=num_batches,
disable=bench_args.disable_tqdm,
desc="Benchmark",
)

for batch_start in range(0, len(total_indices), bench_args.batch_size):
batch_end = min(batch_start + bench_args.batch_size, len(total_indices))
batch_indices = total_indices[batch_start:batch_end]

batch_items = dataset.get_batch(batch_indices)
batch_prompts = [item["prompt"] for item in batch_items]
for batch_start in range(0, total_count, bench_args.batch_size):
batch_end = min(batch_start + bench_args.batch_size, total_count)
batch_prompts = all_prompts[batch_start:batch_end]

# Generate batch
batch_output = generate_batch(
engine, bench_args, batch_prompts, sampling_params
)
Expand All @@ -351,24 +326,20 @@ def throughput_test(
pbar.close()
total_duration = time.perf_counter() - start_time

# Calculate metrics
resolution = (bench_args.width, bench_args.height, bench_args.num_frames)
metrics = calculate_metrics(
outputs,
total_duration,
task_type="unknown",
resolution=resolution,
num_requests=len(total_indices),
num_requests=total_count,
)

# Display results
display_results(
metrics,
bench_args,
model_path=server_args.model_path,
)

# Save results
if bench_args.output_file:
save_results(metrics, bench_args, server_args)

Expand Down Expand Up @@ -426,7 +397,7 @@ def save_results(
"metadata": {
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
"model_path": server_args.model_path,
"task_type": "unknown",
"task_type": bench_args.task_name,
"backend": "engine",
},
"configuration": {
Expand Down Expand Up @@ -458,14 +429,11 @@ def main():

args = parser.parse_args()

# Create ServerArgs and BenchArgs
server_args = ServerArgs.from_cli_args(args)
bench_args = BenchArgs.from_cli_args(args)

# Set global server args
set_global_server_args(server_args)

# Run benchmark
result = throughput_test(server_args, bench_args)

return result
Expand Down
Loading
Loading