Skip to content
Previous commit
Next commit
more style fix
  • Loading branch information
sniper35 committed Feb 8, 2026
commit 50745a3f88579d3b8f4356aef57b81dafd952199
38 changes: 21 additions & 17 deletions examples/diffusion_router/bench_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import subprocess
import sys
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Iterable

import requests

Expand All @@ -39,7 +39,10 @@ def _require_non_empty_model(model: str) -> str:


def _wait_for_health(
url: str, timeout: int, label: str, proc: subprocess.Popen | None = None,
url: str,
timeout: int,
label: str,
proc: subprocess.Popen | None = None,
) -> None:
start = time.time()
last_print = 0.0
Expand Down Expand Up @@ -106,10 +109,7 @@ def _reserve_available_port(host: str, preferred_port: int, used_ports: set[int]
used_ports.add(port)
return port

raise RuntimeError(
f"Unable to reserve a free port for host {host}. "
f"Preferred start={preferred_port}."
)
raise RuntimeError(f"Unable to reserve a free port for host {host}. " f"Preferred start={preferred_port}.")


def _parse_gpu_id_list(raw: str) -> list[str]:
Expand Down Expand Up @@ -178,9 +178,13 @@ def main() -> int:
parser.add_argument("--model", type=str, required=True, help="Diffusion model HF ID or local path.")
parser.add_argument("--router-host", type=str, default="127.0.0.1", help="Router bind host.")
parser.add_argument("--router-port", type=int, default=30080, help="Router port.")
parser.add_argument("--routing-algorithm", type=str, default="least-request",
choices=["least-request", "round-robin", "random"],
help="Load-balancing algorithm for the router.")
parser.add_argument(
"--routing-algorithm",
type=str,
default="least-request",
choices=["least-request", "round-robin", "random"],
help="Load-balancing algorithm for the router.",
)
parser.add_argument("--router-verbose", action="store_true", help="Enable router verbose logging.")
parser.add_argument("--router-extra-args", type=str, default="", help="Extra args for the router demo script.")

Expand Down Expand Up @@ -250,11 +254,10 @@ def main() -> int:

try:
import sglang # noqa: F401
except ImportError:
except ImportError as exc:
raise RuntimeError(
"sglang is not installed.\n"
"Install with: uv pip install \"sglang[diffusion]\" --prerelease=allow"
)
"sglang is not installed.\n" 'Install with: uv pip install "sglang[diffusion]" --prerelease=allow'
) from exc
env = dict(os.environ)

worker_urls = list(args.worker_urls)
Expand Down Expand Up @@ -295,9 +298,7 @@ def main() -> int:

for i, _ in enumerate(worker_urls):
preferred_master = args.worker_master_port_base + i * args.worker_internal_port_stride
preferred_scheduler = (
args.worker_scheduler_port_base + i * args.worker_internal_port_stride
)
preferred_scheduler = args.worker_scheduler_port_base + i * args.worker_internal_port_stride
master_port = _reserve_available_port(args.worker_host, preferred_master, reserved_ports)
scheduler_port = _reserve_available_port(args.worker_host, preferred_scheduler, reserved_ports)
worker_internal_ports.append((master_port, scheduler_port))
Expand Down Expand Up @@ -371,7 +372,10 @@ def main() -> int:
flush=True,
)

print(f"[bench] Waiting for {len(worker_urls)} worker(s) to become healthy (this may take several minutes)...", flush=True)
print(
f"[bench] Waiting for {len(worker_urls)} worker(s) to become healthy (this may take several minutes)...",
flush=True,
)
for i, url in enumerate(worker_urls):
_wait_for_health(url, args.wait_timeout, f"worker {url}", proc=processes[i])

Expand Down
111 changes: 74 additions & 37 deletions examples/diffusion_router/bench_routing_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ def _require_non_empty_model(model: str) -> str:


def main() -> int:
parser = argparse.ArgumentParser(
description="Compare routing algorithms by running bench_router.py for each."
)
parser = argparse.ArgumentParser(description="Compare routing algorithms by running bench_router.py for each.")
parser.add_argument("--model", type=str, required=True, help="Diffusion model HF ID or local path.")
parser.add_argument(
"--algorithms", nargs="+", default=ALL_ALGORITHMS,
choices=ALL_ALGORITHMS, help="Algorithms to compare (default: all three).",
"--algorithms",
nargs="+",
default=ALL_ALGORITHMS,
choices=ALL_ALGORITHMS,
help="Algorithms to compare (default: all three).",
)
parser.add_argument("--output-dir", type=str, default=None, help="Output directory for results.")

Expand Down Expand Up @@ -100,25 +101,44 @@ def main() -> int:
out_file = output_dir / f"bench_{algo}.json"

bench_cmd = [
py, "examples/diffusion_router/bench_router.py",
"--model", args.model,
"--routing-algorithm", algo,
"--num-workers", str(args.num_workers),
"--num-prompts", str(args.num_prompts),
"--max-concurrency", str(args.max_concurrency),
"--num-gpus-per-worker", str(args.num_gpus_per_worker),
"--worker-base-port", str(args.worker_base_port),
"--worker-port-stride", str(args.worker_port_stride),
"--worker-master-port-base", str(args.worker_master_port_base),
"--worker-scheduler-port-base", str(args.worker_scheduler_port_base),
"--worker-internal-port-stride", str(args.worker_internal_port_stride),
"--router-host", args.router_host,
"--router-port", str(args.router_port),
"--dataset", args.dataset,
"--request-rate", str(args.request_rate),
"--wait-timeout", str(args.wait_timeout),
"--log-level", args.log_level,
"--output-file", str(out_file),
py,
"examples/diffusion_router/bench_router.py",
"--model",
args.model,
"--routing-algorithm",
algo,
"--num-workers",
str(args.num_workers),
"--num-prompts",
str(args.num_prompts),
"--max-concurrency",
str(args.max_concurrency),
"--num-gpus-per-worker",
str(args.num_gpus_per_worker),
"--worker-base-port",
str(args.worker_base_port),
"--worker-port-stride",
str(args.worker_port_stride),
"--worker-master-port-base",
str(args.worker_master_port_base),
"--worker-scheduler-port-base",
str(args.worker_scheduler_port_base),
"--worker-internal-port-stride",
str(args.worker_internal_port_stride),
"--router-host",
args.router_host,
"--router-port",
str(args.router_port),
"--dataset",
args.dataset,
"--request-rate",
str(args.request_rate),
"--wait-timeout",
str(args.wait_timeout),
"--log-level",
args.log_level,
"--output-file",
str(out_file),
]
if args.worker_urls:
bench_cmd += ["--worker-urls", *args.worker_urls]
Expand Down Expand Up @@ -178,14 +198,23 @@ def main() -> int:
data = results.get(algo, {})
if "error" in data:
parsed[algo] = None
csv_rows.append({
"algorithm": algo, "throughput_qps": "", "latency_mean": "",
"latency_median": "", "latency_p99": "", "duration": "",
"completed_requests": "", "failed_requests": "",
"throughput_qps_delta_pct": "", "latency_mean_delta_pct": "",
"latency_median_delta_pct": "", "latency_p99_delta_pct": "",
"error": data["error"],
})
csv_rows.append(
{
"algorithm": algo,
"throughput_qps": "",
"latency_mean": "",
"latency_median": "",
"latency_p99": "",
"duration": "",
"completed_requests": "",
"failed_requests": "",
"throughput_qps_delta_pct": "",
"latency_mean_delta_pct": "",
"latency_median_delta_pct": "",
"latency_p99_delta_pct": "",
"error": data["error"],
}
)
continue

row = {
Expand Down Expand Up @@ -260,11 +289,19 @@ def _fmt_int(v):

csv_path = output_dir / "routing_algorithm_comparison.csv"
fieldnames = [
"algorithm", "throughput_qps", "throughput_qps_delta_pct",
"latency_mean", "latency_mean_delta_pct",
"latency_median", "latency_median_delta_pct",
"latency_p99", "latency_p99_delta_pct",
"duration", "completed_requests", "failed_requests", "error",
"algorithm",
"throughput_qps",
"throughput_qps_delta_pct",
"latency_mean",
"latency_mean_delta_pct",
"latency_median",
"latency_median_delta_pct",
"latency_p99",
"latency_p99_delta_pct",
"duration",
"completed_requests",
"failed_requests",
"error",
]
with open(csv_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
Expand Down
10 changes: 7 additions & 3 deletions examples/diffusion_router/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,13 @@ def main():
parser.add_argument("--timeout", type=float, default=None, help="Request timeout in seconds")
parser.add_argument("--health-check-interval", type=int, default=10, help="Seconds between health checks")
parser.add_argument("--health-check-failure-threshold", type=int, default=3, help="Failures before quarantine")
parser.add_argument("--routing-algorithm", type=str, default="least-request",
choices=["least-request", "round-robin", "random"],
help="Load-balancing algorithm")
parser.add_argument(
"--routing-algorithm",
type=str,
default="least-request",
choices=["least-request", "round-robin", "random"],
help="Load-balancing algorithm",
)
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
args = parser.parse_args()

Expand Down
26 changes: 18 additions & 8 deletions miles/router/diffusion_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,16 @@ async def _send(worker_url):
@staticmethod
def _sanitize_response_headers(headers) -> dict:
"""Remove hop-by-hop and encoding headers that no longer match buffered content."""
hop_by_hop = {"connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers",
"transfer-encoding", "upgrade"}
hop_by_hop = {
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailers",
"transfer-encoding",
"upgrade",
}
dropped = {"content-length", "content-encoding"}
return {k: v for k, v in headers.items() if k.lower() not in hop_by_hop | dropped}

Expand Down Expand Up @@ -233,12 +241,14 @@ async def health_workers(self, request: Request):
"""Per-worker health and load information."""
workers = []
for url, count in self.worker_request_counts.items():
workers.append({
"url": url,
"active_requests": count,
"is_dead": url in self.dead_workers,
"consecutive_failures": self.worker_failure_counts.get(url, 0),
})
workers.append(
{
"url": url,
"active_requests": count,
"is_dead": url in self.dead_workers,
"consecutive_failures": self.worker_failure_counts.get(url, 0),
}
)
return JSONResponse(content={"workers": workers})

async def update_weights_from_disk(self, request: Request):
Expand Down