From 2d81aedcf05caa4905ad1ac1e9f3e6fb9fbac833 Mon Sep 17 00:00:00 2001 From: alphabetc1 <2508695655@qq.com> Date: Wed, 18 Feb 2026 00:37:58 +0800 Subject: [PATCH 1/2] feat: diffusion router --- .codespellrc | 2 + .gitignore | 6 +- .pre-commit-config.yaml | 38 ++ README.md | 213 ++++++- docs/update_weights_from_disk.md | 65 +++ pyproject.toml | 36 ++ src/sglang_diffusion_routing/__init__.py | 5 + src/sglang_diffusion_routing/__main__.py | 4 + src/sglang_diffusion_routing/cli/__init__.py | 1 + src/sglang_diffusion_routing/cli/main.py | 126 ++++ .../router/__init__.py | 5 + .../router/diffusion_router.py | 436 ++++++++++++++ .../diffusion_router/bench_router.py | 536 ++++++++++++++++++ .../bench_routing_algorithms.py | 336 +++++++++++ tests/unit/test_cli.py | 56 ++ tests/unit/test_diffusion_router.py | 229 ++++++++ tests/unit/test_router_endpoints.py | 68 +++ 17 files changed, 2151 insertions(+), 11 deletions(-) create mode 100644 .codespellrc create mode 100644 .pre-commit-config.yaml create mode 100644 docs/update_weights_from_disk.md create mode 100644 pyproject.toml create mode 100644 src/sglang_diffusion_routing/__init__.py create mode 100644 src/sglang_diffusion_routing/__main__.py create mode 100644 src/sglang_diffusion_routing/cli/__init__.py create mode 100644 src/sglang_diffusion_routing/cli/main.py create mode 100644 src/sglang_diffusion_routing/router/__init__.py create mode 100644 src/sglang_diffusion_routing/router/diffusion_router.py create mode 100644 tests/benchmarks/diffusion_router/bench_router.py create mode 100644 tests/benchmarks/diffusion_router/bench_routing_algorithms.py create mode 100644 tests/unit/test_cli.py create mode 100644 tests/unit/test_diffusion_router.py create mode 100644 tests/unit/test_router_endpoints.py diff --git a/.codespellrc b/.codespellrc new file mode 100644 index 0000000..64f2a5d --- /dev/null +++ b/.codespellrc @@ -0,0 +1,2 @@ +[codespell] +ignore-words-list = te diff --git a/.gitignore b/.gitignore index 
b7faf40..d6d15a0 100644 --- a/.gitignore +++ b/.gitignore @@ -182,11 +182,11 @@ cython_debug/ .abstra/ # Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, +# and can be added to the global gitignore or merged into this file. However, if you prefer, # you could uncomment the following to ignore the entire vscode folder -# .vscode/ +.vscode/ # Ruff stuff: .ruff_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..4a9a802 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,38 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-toml + - id: check-yaml + - id: check-ast + - id: check-added-large-files + - id: check-merge-conflict + - id: debug-statements + - id: detect-private-key + - id: no-commit-to-branch + + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.7 + hooks: + - id: ruff + args: + - --select=F401,F821 + - --fix + + - repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black + + - repo: https://github.com/codespell-project/codespell + rev: v2.4.1 + hooks: + - id: codespell + args: ['--config', '.codespellrc'] diff --git a/README.md b/README.md index 2f9c2dd..6a23eb0 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,210 @@ # sglang-diffusion-routing -A demonstrative example of running SGLang Diffusion with a DP router, which supports `generation` (a lot of methods, including [SDE/CPS](https://github.com/sgl-project/sglang/pull/18806)), 
`update_weights_from_disk` in PR [18306](https://github.com/sgl-project/sglang/pull/18306), and `health_check`. +A lightweight router for SGLang diffusion workers. -1. Copy all the codes of https://github.com/radixark/miles/pull/544 to here with sincere acknowledgment. -2. Write up a detailed README on how to use SGLang Diffusion Router to launch multiple instances and send requests. +It provides worker registration, load balancing, health checking, and request proxying for diffusion generation APIs. -For example, given that we can make a Python binding of the sglang-d router: +## Highlights -1. pip install sglang-d-router (Only for local development right now, clone the repository and run `pip install .` from the root directory. No need to make a PyPi) -2. pip install "sglang[diffusion]" -3. launching command (how to use sglang-d-router to launch n sglang diffusion servers) -4. Sending demonstrative requests +- `least-request` routing by default, with `round-robin` and `random` available as alternative algorithms. +- Background health checks with quarantine after repeated failures. +- Router APIs for worker registration, health inspection, and proxy forwarding. +- `update_weights_from_disk` broadcast to all healthy workers. + +## Installation + +From repository root: + +```bash +python3 -m venv .venv +. .venv/bin/activate +pip install . +``` + +Development install: + +```bash +pip install -e . 
+``` + +Run tests: + +```bash +pip install pytest +pytest tests/unit -v +``` + +Workers require SGLang diffusion support: + +```bash +pip install "sglang[diffusion]" +``` + +## Quick Start + +### 1) Start diffusion workers + +```bash +# worker 1 +CUDA_VISIBLE_DEVICES=0 sglang serve \ + --model-path stabilityai/stable-diffusion-3-medium-diffusers \ + --num-gpus 1 \ + --host 127.0.0.1 \ + --port 30000 + +# worker 2 +CUDA_VISIBLE_DEVICES=1 sglang serve \ + --model-path stabilityai/stable-diffusion-3-medium-diffusers \ + --num-gpus 1 \ + --host 127.0.0.1 \ + --port 30001 +``` + +### 2) Start the router + +Script entry: + +```bash +sglang-d-router --port 30080 \ + --worker-urls http://localhost:30000 http://localhost:30001 +``` + +Module entry: + +```bash +python -m sglang_diffusion_routing --port 30080 \ + --worker-urls http://localhost:30000 http://localhost:30001 +``` + +Or start empty and add workers later: + +```bash +sglang-d-router --port 30080 +curl -X POST "http://localhost:30080/add_worker?url=http://localhost:30000" +``` + +### 3) Test the router + +```bash +# Check router health +curl http://localhost:30080/health + +# List registered workers +curl http://localhost:30080/list_workers + +# Image generation request (SD3) +curl -X POST http://localhost:30080/generate \ + -H "Content-Type: application/json" \ + -d '{ + "model": "stabilityai/stable-diffusion-3-medium-diffusers", + "prompt": "a cute cat", + "num_images": 1 + }' + +# Video generation request +curl -X POST http://localhost:30080/generate_video \ + -H "Content-Type: application/json" \ + -d '{ + "model": "stabilityai/stable-video-diffusion", + "prompt": "a flowing river" + }' + +# Check per-worker health and load +curl http://localhost:30080/health_workers +``` + +## Router API + +- `POST /add_worker`: add worker via query (`?url=`) or JSON body. +- `GET /list_workers`: list registered workers. +- `GET /health`: aggregated router health. 
+- `GET /health_workers`: per-worker health and active request counts. +- `POST /generate`: forwards to worker `/v1/images/generations`. +- `POST /generate_video`: forwards to worker `/v1/videos`. +- `POST /update_weights_from_disk`: broadcasts the request to all healthy workers. +- `GET|POST|PUT|DELETE /{path}`: catch-all proxy; any other path is forwarded to a selected worker. + +## `update_weights_from_disk` behavior + +Full details: [docs/update_weights_from_disk.md](docs/update_weights_from_disk.md) + +- The router forwards request payloads as-is to each healthy worker. +- The router does not validate payload schema; payload semantics are worker-defined. +- Worker servers must implement `POST /update_weights_from_disk`. + +Example: + +```bash +curl -X POST http://localhost:30080/update_weights_from_disk \ + -H "Content-Type: application/json" \ + -d '{"model_path": "/path/to/new/weights"}' +``` + +Response shape: + +```json +{ + "results": [ + { + "worker_url": "http://localhost:30000", + "status_code": 200, + "body": { + "ok": true + } + } + ] +} +``` + +## Benchmark Scripts + +Benchmark scripts are available under `tests/benchmarks/diffusion_router/` and are intended for manual runs. +They are not part of the default unit test collection (`pytest tests/unit -v`). + +Single benchmark: + +```bash +python tests/benchmarks/diffusion_router/bench_router.py \ + --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --num-workers 2 \ + --num-prompts 20 \ + --max-concurrency 4 +``` + +Algorithm comparison: + +```bash +python tests/benchmarks/diffusion_router/bench_routing_algorithms.py \ + --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --num-workers 2 \ + --num-prompts 20 \ + --max-concurrency 4 +``` + +## Project Layout + +```text +. 
+├── docs/ +│ └── update_weights_from_disk.md +├── src/sglang_diffusion_routing/ +│ ├── cli/ +│ └── router/ +├── tests/ +│ ├── benchmarks/ +│ │ └── diffusion_router/ +│ │ ├── bench_router.py +│ │ └── bench_routing_algorithms.py +│ └── unit/ +├── pyproject.toml +└── README.md +``` + +## Acknowledgment + +This project is derived from [radixark/miles#544](https://github.com/radixark/miles/pull/544). Thanks to the original authors for their work. + +## Notes + +- Quarantined workers are intentionally not auto-reintroduced: a worker marked dead stays excluded from routing and broadcasts, and re-adding the same URL via `/add_worker` does not revive it. +- Router responses are fully buffered; streaming passthrough is not implemented. diff --git a/docs/update_weights_from_disk.md b/docs/update_weights_from_disk.md new file mode 100644 index 0000000..c6362b5 --- /dev/null +++ b/docs/update_weights_from_disk.md @@ -0,0 +1,65 @@ +# update_weights_from_disk + +This document describes the behavior of `POST /update_weights_from_disk` in this repository. + +## Router behavior + +The router does not validate or transform payload fields. +It forwards the original request body to every healthy worker and returns per-worker results. + +Payload semantics are therefore defined by the worker implementation, not by the router. + +## Requirements + +- Worker servers must implement `POST /update_weights_from_disk`. +- For SGLang workers, use a version that includes this endpoint. +- Weights must match your worker runtime expectations. + +## Basic example + +```bash +curl -X POST http://localhost:30080/update_weights_from_disk \ + -H "Content-Type: application/json" \ + -d '{"model_path": "/path/to/new/weights"}' +``` + +## Optional fields + +Some worker versions support optional fields such as `target_modules`: + +```bash +curl -X POST http://localhost:30080/update_weights_from_disk \ + -H "Content-Type: application/json" \ + -d '{"model_path": "/path/to/weights", "target_modules": ["transformer", "vae"]}' +``` + +If your worker version does not support extra fields, the worker itself rejects the request, and the router reports that worker's status code and body in its `results` entry. 
+ +## Response shape + +The router response includes one item per healthy worker: + +```json +{ + "results": [ + { + "worker_url": "http://localhost:10090", + "status_code": 200, + "body": { + "ok": true + } + }, + { + "worker_url": "http://localhost:10092", + "status_code": 500, + "body": { + "error": "worker-side failure" + } + } + ] +} +``` + +Notes: +- Quarantined workers are excluded from broadcast. +- Transport/runtime exceptions are surfaced as per-worker `status_code=502`. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..107ac21 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "sglang-diffusion-routing" +version = "0.1.0" +description = "Load-balancing router for SGLang diffusion workers" +readme = "README.md" +requires-python = ">=3.10" +license = { text = "MIT" } +dependencies = [ + "fastapi>=0.110", + "httpx>=0.27", + "uvicorn>=0.30", +] +classifiers = [ + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Intended Audience :: Developers", +] + +[project.scripts] +sglang-d-router = "sglang_diffusion_routing.cli.main:main" + +[tool.setuptools] +package-dir = { "" = "src" } + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.pytest.ini_options] +testpaths = ["tests/unit"] diff --git a/src/sglang_diffusion_routing/__init__.py b/src/sglang_diffusion_routing/__init__.py new file mode 100644 index 0000000..b2bbffb --- /dev/null +++ b/src/sglang_diffusion_routing/__init__.py @@ -0,0 +1,5 @@ +"""Public package API for sglang diffusion routing.""" + +from sglang_diffusion_routing.router.diffusion_router import DiffusionRouter + +__all__ = ["DiffusionRouter"] diff --git a/src/sglang_diffusion_routing/__main__.py 
b/src/sglang_diffusion_routing/__main__.py new file mode 100644 index 0000000..a6ee04b --- /dev/null +++ b/src/sglang_diffusion_routing/__main__.py @@ -0,0 +1,4 @@ +from sglang_diffusion_routing.cli.main import main + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/sglang_diffusion_routing/cli/__init__.py b/src/sglang_diffusion_routing/cli/__init__.py new file mode 100644 index 0000000..ff751b6 --- /dev/null +++ b/src/sglang_diffusion_routing/cli/__init__.py @@ -0,0 +1 @@ +"""CLI package for sglang diffusion routing.""" diff --git a/src/sglang_diffusion_routing/cli/main.py b/src/sglang_diffusion_routing/cli/main.py new file mode 100644 index 0000000..3910826 --- /dev/null +++ b/src/sglang_diffusion_routing/cli/main.py @@ -0,0 +1,126 @@ +# This module is derived from radixark/miles#544. +# See README.md for full acknowledgment. + +from __future__ import annotations + +import argparse +import sys + +from sglang_diffusion_routing import DiffusionRouter + + +def _run_router_server( + args: argparse.Namespace, + worker_urls: list[str] | None = None, + log_prefix: str = "[router]", +) -> None: + try: + import uvicorn # type: ignore[import-not-found] + except ImportError as exc: + raise RuntimeError( + "uvicorn is required to run router. 
Install with: pip install uvicorn" + ) from exc + + worker_urls = list( + worker_urls if worker_urls is not None else args.worker_urls or [] + ) + router = DiffusionRouter(args, verbose=args.verbose) + for url in worker_urls: + router.register_worker(url) + + print(f"{log_prefix} starting router on {args.host}:{args.port}", flush=True) + print( + f"{log_prefix} workers: {list(router.worker_request_counts.keys()) or '(none - add via POST /add_worker)'}", + flush=True, + ) + uvicorn.run( + router.app, + host=args.host, + port=args.port, + log_level=getattr(args, "log_level", "info"), + ) + + +def _add_router_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Router bind address." + ) + parser.add_argument("--port", type=int, default=30080, help="Router port.") + parser.add_argument( + "--worker-urls", nargs="*", default=[], help="Initial diffusion worker URLs." + ) + parser.add_argument( + "--max-connections", + type=int, + default=100, + help="Max concurrent connections to workers.", + ) + parser.add_argument( + "--timeout", + type=float, + default=120.0, + help="Router-to-worker request timeout in seconds.", + ) + parser.add_argument( + "--health-check-interval", + type=int, + default=10, + help="Seconds between health checks.", + ) + parser.add_argument( + "--health-check-failure-threshold", + type=int, + default=3, + help="Consecutive failures before quarantine.", + ) + parser.add_argument( + "--routing-algorithm", + type=str, + default="least-request", + choices=["least-request", "round-robin", "random"], + help="Load-balancing algorithm.", + ) + parser.add_argument( + "--verbose", action="store_true", help="Enable verbose logging." + ) + parser.add_argument( + "--log-level", type=str, default="info", help="Uvicorn log level." 
+ ) + + +def _handle_router(args: argparse.Namespace) -> int: + _run_router_server( + args, worker_urls=list(args.worker_urls), log_prefix="[sglang-d-router]" + ) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="sglang-d-router", + description="SGLang diffusion router CLI.", + ) + _add_router_args(parser) + parser.set_defaults(handler=_handle_router) + return parser + + +def run_cli(argv: list[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + handler = args.handler + return handler(args) + + +def main(argv: list[str] | None = None) -> int: + try: + return run_cli(argv) + except KeyboardInterrupt: + return 130 + except Exception as exc: + print(f"[sglang-d-router] error: {exc}", file=sys.stderr) + return 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/sglang_diffusion_routing/router/__init__.py b/src/sglang_diffusion_routing/router/__init__.py new file mode 100644 index 0000000..a699d0d --- /dev/null +++ b/src/sglang_diffusion_routing/router/__init__.py @@ -0,0 +1,5 @@ +"""Router implementations for sglang diffusion routing.""" + +from sglang_diffusion_routing.router.diffusion_router import DiffusionRouter + +__all__ = ["DiffusionRouter"] diff --git a/src/sglang_diffusion_routing/router/diffusion_router.py b/src/sglang_diffusion_routing/router/diffusion_router.py new file mode 100644 index 0000000..c2b9c87 --- /dev/null +++ b/src/sglang_diffusion_routing/router/diffusion_router.py @@ -0,0 +1,436 @@ +# This module is derived from radixark/miles#544. +# See README.md for full acknowledgment. 
+ +import asyncio +import ipaddress +import json +import logging +import random +from urllib.parse import urlsplit, urlunsplit + +import httpx +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from starlette.responses import Response + +logger = logging.getLogger(__name__) + +_METADATA_HOSTS = {"169.254.169.254", "metadata.google.internal"} + + +class DiffusionRouter: + def __init__(self, args, verbose: bool = False): + """Initialize the router for load-balancing sglang-diffusion workers.""" + self.args = args + self.verbose = verbose + + self.app = FastAPI() + self.app.add_event_handler("startup", self._start_background_health_check) + self.app.add_event_handler("shutdown", self._shutdown) + + # URL -> active request count + self.worker_request_counts: dict[str, int] = {} + # URL -> consecutive health check failures + self.worker_failure_counts: dict[str, int] = {} + # quarantined workers excluded from routing + self.dead_workers: set[str] = set() + self._health_task: asyncio.Task | None = None + + self.routing_algorithm = getattr(args, "routing_algorithm", "least-request") + self._rr_index = 0 + + max_connections = getattr(args, "max_connections", 100) + timeout = getattr(args, "timeout", 120.0) + if timeout is None: + timeout = 120.0 + + self.client = httpx.AsyncClient( + limits=httpx.Limits(max_connections=max_connections), + timeout=httpx.Timeout(timeout), + ) + + self._setup_routes() + + def _setup_routes(self) -> None: + self.app.post("/add_worker")(self.add_worker) + self.app.get("/list_workers")(self.list_workers) + self.app.get("/health")(self.health) + self.app.get("/health_workers")(self.health_workers) + self.app.post("/generate")(self.generate) + self.app.post("/generate_video")(self.generate_video) + self.app.post("/update_weights_from_disk")(self.update_weights_from_disk) + self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])( + self.proxy + ) + + async def _start_background_health_check(self) 
-> None: + if self._health_task is None or self._health_task.done(): + self._health_task = asyncio.create_task(self._health_check_loop()) + + async def _shutdown(self) -> None: + if self._health_task is not None: + self._health_task.cancel() + try: + await self._health_task + except asyncio.CancelledError: + pass + await self.client.aclose() + + async def _check_worker_health(self, url: str) -> tuple[str, bool]: + try: + response = await self.client.get(f"{url}/health", timeout=5.0) + if response.status_code == 200: + return url, True + logger.debug( + "[diffusion-router] Worker %s unhealthy (status %s)", + url, + response.status_code, + ) + except Exception as exc: + logger.debug( + "[diffusion-router] Worker %s health check failed: %s", url, exc + ) + return url, False + + async def _health_check_loop(self) -> None: + """Background loop to monitor worker health and quarantine failing workers.""" + interval = getattr(self.args, "health_check_interval", 10) + threshold = getattr(self.args, "health_check_failure_threshold", 3) + + while True: + try: + await asyncio.sleep(interval) + + urls = [ + u for u in self.worker_request_counts if u not in self.dead_workers + ] + if not urls: + continue + + results = await asyncio.gather( + *(self._check_worker_health(url) for url in urls) + ) + for url, is_healthy in results: + if not is_healthy: + failures = self.worker_failure_counts.get(url, 0) + 1 + self.worker_failure_counts[url] = failures + if failures >= threshold: + logger.warning( + "[diffusion-router] Worker %s failed %s consecutive checks. Marking DEAD.", + url, + threshold, + ) + self.dead_workers.add(url) + else: + self.worker_failure_counts[url] = 0 + + healthy = len(self.worker_request_counts) - len(self.dead_workers) + logger.debug( + "[diffusion-router] Health check complete. 
%s workers healthy.", + healthy, + ) + except asyncio.CancelledError: + raise + except Exception as exc: + logger.error( + "[diffusion-router] Unexpected error in health check loop: %s", + exc, + exc_info=True, + ) + await asyncio.sleep(5) + + def _use_url(self) -> str: + """Select a worker URL based on the configured routing algorithm.""" + if not self.worker_request_counts: + raise RuntimeError("No workers registered in the pool") + + valid_workers = [ + w for w in self.worker_request_counts if w not in self.dead_workers + ] + if not valid_workers: + raise RuntimeError("No healthy workers available in the pool") + + if self.routing_algorithm == "round-robin": + url = valid_workers[self._rr_index % len(valid_workers)] + self._rr_index = (self._rr_index + 1) % len(valid_workers) + elif self.routing_algorithm == "random": + url = random.choice(valid_workers) + else: + url = min(valid_workers, key=self.worker_request_counts.get) + + self.worker_request_counts[url] += 1 + return url + + def _finish_url(self, url: str) -> None: + """Mark the request to the given URL as finished.""" + if url not in self.worker_request_counts: + logger.error("[diffusion-router] URL %s not recognized in _finish_url", url) + return + + next_count = self.worker_request_counts[url] - 1 + if next_count < 0: + logger.error( + "[diffusion-router] URL %s count went negative; clamping to zero", url + ) + next_count = 0 + self.worker_request_counts[url] = next_count + + def _build_proxy_response( + self, content: bytes, status_code: int, headers: dict + ) -> Response: + """ + Build an HTTP response from proxied bytes. + + If the payload is small and valid JSON, return JSONResponse. + Otherwise, return raw bytes to avoid expensive JSON re-encoding. 
+ """ + content_type = headers.get("content-type", "") + max_json_reencode_bytes = 256 * 1024 + if len(content) <= max_json_reencode_bytes: + try: + data = json.loads(content) + return JSONResponse( + content=data, status_code=status_code, headers=headers + ) + except Exception: + pass + + return Response( + content=content, + status_code=status_code, + headers=headers, + media_type=content_type, + ) + + async def _forward_to_worker(self, request: Request, path: str) -> Response: + """Forward a request to a selected worker and return the response.""" + try: + worker_url = self._use_url() + except RuntimeError as exc: + return JSONResponse(status_code=503, content={"error": str(exc)}) + + try: + query = request.url.query + url = ( + f"{worker_url}/{path}" if not query else f"{worker_url}/{path}?{query}" + ) + body = await request.body() + headers = dict(request.headers) + if body is not None: + headers = { + k: v + for k, v in headers.items() + if k.lower() not in ("content-length", "transfer-encoding") + } + + response = await self.client.request( + request.method, url, content=body, headers=headers + ) + content = await response.aread() + resp_headers = self._sanitize_response_headers(response.headers) + return self._build_proxy_response( + content, response.status_code, resp_headers + ) + except Exception as exc: + logger.error( + "[diffusion-router] Failed to forward request to %s: %s", + worker_url, + exc, + ) + return JSONResponse( + status_code=502, content={"error": f"Worker request failed: {exc}"} + ) + finally: + self._finish_url(worker_url) + + async def _broadcast_to_workers( + self, path: str, body: bytes, headers: dict + ) -> list[dict]: + """Send a request to all healthy workers and collect results.""" + urls = [u for u in self.worker_request_counts if u not in self.dead_workers] + if not urls: + return [] + + async def _send(worker_url: str) -> dict: + try: + response = await self.client.post( + f"{worker_url}/{path}", content=body, headers=headers + 
) + content = await response.aread() + return { + "worker_url": worker_url, + "status_code": response.status_code, + "body": self._try_decode_json(content), + } + except Exception as exc: + return { + "worker_url": worker_url, + "status_code": 502, + "body": {"error": str(exc)}, + } + + return await asyncio.gather(*(_send(u) for u in urls)) + + @staticmethod + def _try_decode_json(content: bytes): + try: + return json.loads(content) + except Exception: + return {"raw": content.decode("utf-8", errors="replace")} + + @staticmethod + def _sanitize_response_headers(headers) -> dict: + """Remove hop-by-hop and encoding headers that no longer match buffered content.""" + hop_by_hop = { + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + } + dropped = {"content-length", "content-encoding"} + return { + k: v for k, v in headers.items() if k.lower() not in hop_by_hop | dropped + } + + @staticmethod + def _normalize_worker_url(url: str) -> str: + if not isinstance(url, str): + raise ValueError("worker_url must be a string") + + raw = url.strip() + if not raw: + raise ValueError("worker_url cannot be empty") + + parsed = urlsplit(raw) + if parsed.scheme not in {"http", "https"}: + raise ValueError("worker_url must start with http:// or https://") + if not parsed.netloc: + raise ValueError("worker_url must include host and port") + if parsed.username or parsed.password: + raise ValueError("worker_url must not include user credentials") + if parsed.query or parsed.fragment: + raise ValueError("worker_url must not include query or fragment") + if parsed.path not in ("", "/"): + raise ValueError("worker_url path is not allowed") + + hostname = (parsed.hostname or "").lower() + if hostname in _METADATA_HOSTS: + raise ValueError("worker_url host is blocked") + parsed_ip = None + try: + parsed_ip = ipaddress.ip_address(hostname) + except ValueError: + # Non-IP hostname + pass + if parsed_ip is 
not None and parsed_ip.is_link_local: + raise ValueError("link-local worker_url hosts are blocked") + + if parsed.port is None: + normalized_netloc = hostname + elif ":" in hostname and not hostname.startswith("["): + normalized_netloc = f"[{hostname}]:{parsed.port}" + else: + normalized_netloc = f"{hostname}:{parsed.port}" + + normalized = urlunsplit((parsed.scheme, normalized_netloc, "", "", "")) + return normalized.rstrip("/") + + async def generate(self, request: Request): + """Route image generation to /v1/images/generations.""" + return await self._forward_to_worker(request, "v1/images/generations") + + async def generate_video(self, request: Request): + """Route video generation to /v1/videos.""" + return await self._forward_to_worker(request, "v1/videos") + + async def health(self, request: Request): + """Aggregated health status: healthy if at least one worker is alive.""" + total = len(self.worker_request_counts) + dead = len(self.dead_workers) + healthy = total - dead + status = "healthy" if healthy > 0 else "unhealthy" + code = 200 if healthy > 0 else 503 + return JSONResponse( + status_code=code, + content={ + "status": status, + "healthy_workers": healthy, + "total_workers": total, + }, + ) + + async def health_workers(self, request: Request): + """Per-worker health and load information.""" + workers = [] + for url, count in self.worker_request_counts.items(): + workers.append( + { + "url": url, + "active_requests": count, + "is_dead": url in self.dead_workers, + "consecutive_failures": self.worker_failure_counts.get(url, 0), + } + ) + return JSONResponse(content={"workers": workers}) + + async def update_weights_from_disk(self, request: Request): + """Broadcast weight reload to all healthy workers.""" + body = await request.body() + headers = dict(request.headers) + results = await self._broadcast_to_workers( + "update_weights_from_disk", body, headers + ) + return JSONResponse(content={"results": results}) + + def register_worker(self, url: str) -> 
None: + """Register a worker URL if not already known.""" + normalized_url = self._normalize_worker_url(url) + if normalized_url not in self.worker_request_counts: + self.worker_request_counts[normalized_url] = 0 + self.worker_failure_counts[normalized_url] = 0 + if self.verbose: + print(f"[diffusion-router] Added new worker: {normalized_url}") + + async def add_worker(self, request: Request): + """Register a new diffusion worker (query: ?url=... or JSON body).""" + worker_url = request.query_params.get("url") or request.query_params.get( + "worker_url" + ) + if not worker_url: + body = await request.body() + try: + payload = json.loads(body) if body else {} + except json.JSONDecodeError: + return JSONResponse( + status_code=400, content={"error": "Invalid JSON body"} + ) + worker_url = payload.get("url") or payload.get("worker_url") + + if not worker_url: + return JSONResponse( + status_code=400, + content={ + "error": "worker_url is required (use query ?url=... or JSON body)" + }, + ) + + try: + self.register_worker(worker_url) + except ValueError as exc: + return JSONResponse(status_code=400, content={"error": str(exc)}) + return { + "status": "success", + "worker_urls": list(self.worker_request_counts.keys()), + } + + async def list_workers(self, request: Request): + """List all registered workers.""" + return {"urls": list(self.worker_request_counts.keys())} + + async def proxy(self, request: Request, path: str): + """Catch-all: forward unmatched requests to a selected worker.""" + return await self._forward_to_worker(request, path) diff --git a/tests/benchmarks/diffusion_router/bench_router.py b/tests/benchmarks/diffusion_router/bench_router.py new file mode 100644 index 0000000..59a16ac --- /dev/null +++ b/tests/benchmarks/diffusion_router/bench_router.py @@ -0,0 +1,536 @@ +#!/usr/bin/env python3 +""" +Launch diffusion workers and router, then run a serving benchmark. 
+ +Example: + python tests/benchmarks/diffusion_router/bench_router.py \ + --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --num-workers 2 \ + --num-prompts 20 \ + --max-concurrency 4 +""" + +from __future__ import annotations + +import argparse +import json +import os +import shlex +import signal +import socket +import subprocess +import sys +import time +from collections.abc import Iterable +from pathlib import Path + +import httpx + + +def _repo_root() -> Path: + # tests/benchmarks/diffusion_router/bench_router.py -> repo root + return Path(__file__).resolve().parents[3] + + +def _require_non_empty_model(model: str) -> str: + normalized = model.strip() + if not normalized: + raise ValueError( + "--model must be a non-empty model ID/path. " + "Detected an empty value, which often means a shell variable such as " + "$MODEL was unset." + ) + return normalized + + +def _infer_client_host(host: str) -> str: + if host in ("0.0.0.0", "::"): + return "127.0.0.1" + return host + + +def _wait_for_health( + url: str, + timeout: int, + label: str, + proc: subprocess.Popen | None = None, +) -> None: + start = time.time() + last_print = 0.0 + while True: + elapsed = time.time() - start + + # Fail fast if a managed process exits unexpectedly. + if proc is not None and proc.poll() is not None: + raise RuntimeError( + f"{label} process exited with code {proc.returncode}. " + "Run the command directly to inspect startup errors." + ) + + try: + resp = httpx.get(f"{url}/health", timeout=1.0) + if resp.status_code == 200: + print(f" [bench] {label} is healthy ({elapsed:.0f}s)", flush=True) + return + except httpx.HTTPError: + pass + + if elapsed - last_print >= 30: + print( + f" [bench] Still waiting for {label}... 
({elapsed:.0f}s elapsed)", + flush=True, + ) + last_print = elapsed + + if elapsed > timeout: + raise TimeoutError(f"Timed out waiting for {label} at {url}.") + time.sleep(1) + + +def _normalize_connect_host(host: str) -> str: + if host in ("0.0.0.0", "::"): + return "127.0.0.1" + if host == "localhost": + return "127.0.0.1" + return host + + +def _is_port_available(host: str, port: int) -> bool: + connect_host = _normalize_connect_host(host) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(0.5) + return sock.connect_ex((connect_host, port)) != 0 + + +def _reserve_available_port( + host: str, preferred_port: int, used_ports: set[int] +) -> int: + if preferred_port < 1 or preferred_port > 65535: + raise ValueError(f"Invalid port: {preferred_port}") + + for port in range(preferred_port, 65536): + if port in used_ports: + continue + if _is_port_available(host, port): + used_ports.add(port) + return port + + for port in range(1024, preferred_port): + if port in used_ports: + continue + if _is_port_available(host, port): + used_ports.add(port) + return port + + raise RuntimeError( + f"Unable to reserve a free port for host {host}. Preferred start={preferred_port}." 
+ ) + + +def _parse_gpu_id_list(raw: str) -> list[str]: + return [item.strip() for item in raw.split(",") if item.strip()] + + +def _detect_gpu_count() -> int: + try: + import torch + + return int(torch.cuda.device_count()) + except Exception: + return 0 + + +def _resolve_gpu_pool( + args: argparse.Namespace, env: dict[str, str] +) -> list[str] | None: + if args.worker_gpu_ids: + return [str(x) for x in args.worker_gpu_ids] + + visible = env.get("CUDA_VISIBLE_DEVICES", "") + if visible: + parsed = _parse_gpu_id_list(visible) + if parsed: + return parsed + + gpu_count = _detect_gpu_count() + if gpu_count > 0: + return [str(i) for i in range(gpu_count)] + return None + + +def _terminate_all(processes: Iterable[subprocess.Popen]) -> None: + procs = list(processes) + + def _signal_group(proc: subprocess.Popen, sig: int) -> None: + try: + os.killpg(proc.pid, sig) + except ProcessLookupError: + pass + except Exception: + if proc.poll() is None: + try: + os.kill(proc.pid, sig) + except ProcessLookupError: + pass + + for proc in procs: + _signal_group(proc, signal.SIGTERM) + + for proc in procs: + try: + proc.wait(timeout=15) + except subprocess.TimeoutExpired: + _signal_group(proc, signal.SIGKILL) + + for proc in procs: + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + pass + + +def _build_pythonpath_with_src(repo_root: Path, env: dict[str, str]) -> dict[str, str]: + out = dict(env) + src_dir = str(repo_root / "src") + old = out.get("PYTHONPATH") + out["PYTHONPATH"] = src_dir if not old else f"{src_dir}:{old}" + return out + + +def _build_bench_command(args: argparse.Namespace, base_url: str) -> list[str]: + cmd = [ + sys.executable, + "-m", + "sglang.bench_serving", + "--backend", + "sglang", + "--base-url", + base_url, + "--model", + args.model, + "--dataset-name", + args.dataset, + "--num-prompts", + str(args.num_prompts), + "--max-concurrency", + str(args.max_concurrency), + "--request-rate", + str(args.request_rate), + "--log-level", + 
args.log_level, + ] + + if args.dataset_path: + cmd += ["--dataset-path", args.dataset_path] + if args.task: + cmd += ["--task", args.task] + if args.width: + cmd += ["--width", str(args.width)] + if args.height: + cmd += ["--height", str(args.height)] + if args.num_frames: + cmd += ["--num-frames", str(args.num_frames)] + if args.fps: + cmd += ["--fps", str(args.fps)] + if args.output_file: + cmd += ["--output-file", args.output_file] + if args.disable_tqdm: + cmd.append("--disable-tqdm") + if args.bench_extra_args: + cmd += shlex.split(args.bench_extra_args) + + return cmd + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Benchmark sglang-d-router with sglang bench_serving." + ) + parser.add_argument( + "--model", type=str, required=True, help="Diffusion model HF ID or local path." + ) + parser.add_argument( + "--router-host", type=str, default="127.0.0.1", help="Router bind host." + ) + parser.add_argument("--router-port", type=int, default=30080, help="Router port.") + parser.add_argument( + "--routing-algorithm", + type=str, + default="least-request", + choices=["least-request", "round-robin", "random"], + help="Load-balancing algorithm for the router.", + ) + parser.add_argument( + "--router-verbose", action="store_true", help="Enable router verbose logging." + ) + parser.add_argument("--router-max-connections", type=int, default=100) + parser.add_argument("--router-timeout", type=float, default=120.0) + parser.add_argument("--router-health-check-interval", type=int, default=10) + parser.add_argument("--router-health-check-failure-threshold", type=int, default=3) + parser.add_argument( + "--router-extra-args", + type=str, + default="", + help="Extra args for the router CLI command.", + ) + + parser.add_argument( + "--worker-host", type=str, default="127.0.0.1", help="Worker bind host." + ) + parser.add_argument( + "--worker-urls", nargs="*", default=[], help="Existing worker URLs to use." 
+ ) + parser.add_argument( + "--num-workers", type=int, default=1, help="Number of workers to launch." + ) + parser.add_argument( + "--worker-base-port", + type=int, + default=10090, + help="Base port for launched workers.", + ) + parser.add_argument( + "--worker-port-stride", + type=int, + default=2, + help="Port increment between launched workers. Keep >=2 to avoid port collisions.", + ) + parser.add_argument( + "--worker-master-port-base", + type=int, + default=30005, + help="Base torch distributed master port for launched workers.", + ) + parser.add_argument( + "--worker-scheduler-port-base", + type=int, + default=5555, + help="Base scheduler port for launched workers.", + ) + parser.add_argument( + "--worker-internal-port-stride", + type=int, + default=1000, + help="Stride used between workers for master/scheduler base ports.", + ) + parser.add_argument( + "--num-gpus-per-worker", type=int, default=1, help="GPUs per worker." + ) + parser.add_argument( + "--worker-gpu-ids", + nargs="*", + default=None, + help=( + "Optional GPU IDs/UUIDs for launched workers. They are consumed in order, " + "in groups of --num-gpus-per-worker." + ), + ) + parser.add_argument( + "--worker-extra-args", + type=str, + default="", + help="Extra args for `sglang serve`.", + ) + parser.add_argument( + "--skip-workers", action="store_true", help="Do not launch workers." 
+ ) + + parser.add_argument( + "--dataset", type=str, default="random", choices=["vbench", "random"] + ) + parser.add_argument("--dataset-path", type=str, default=None) + parser.add_argument("--num-prompts", type=int, default=20) + parser.add_argument("--max-concurrency", type=int, default=1) + parser.add_argument("--request-rate", type=float, default=float("inf")) + parser.add_argument("--task", type=str, default=None) + parser.add_argument("--width", type=int, default=None) + parser.add_argument("--height", type=int, default=None) + parser.add_argument("--num-frames", type=int, default=None) + parser.add_argument("--fps", type=int, default=None) + parser.add_argument("--output-file", type=str, default=None) + parser.add_argument("--disable-tqdm", action="store_true") + parser.add_argument("--log-level", type=str, default="INFO") + parser.add_argument( + "--bench-extra-args", type=str, default="", help="Extra args for bench_serving." + ) + parser.add_argument( + "--wait-timeout", + type=int, + default=1200, + help="Seconds to wait for services health.", + ) + + args = parser.parse_args() + args.model = _require_non_empty_model(args.model) + + repo_root = _repo_root() + managed_processes: list[subprocess.Popen] = [] + launched_workers: list[tuple[str, subprocess.Popen]] = [] + worker_urls = list(args.worker_urls) + used_ports: set[int] = set() + + try: + import sglang # noqa: F401 + except ImportError as exc: + raise RuntimeError( + "sglang is not installed.\n" + 'Install with: uv pip install "sglang[diffusion]" --prerelease=allow' + ) from exc + + try: + if args.num_workers < 0: + raise ValueError("--num-workers must be >= 0") + if args.num_gpus_per_worker < 1: + raise ValueError("--num-gpus-per-worker must be >= 1") + + if not args.skip_workers: + host_for_url = _infer_client_host(args.worker_host) + worker_env = os.environ.copy() + gpu_pool = _resolve_gpu_pool(args, worker_env) + + needed = args.num_workers * args.num_gpus_per_worker + if gpu_pool and 
len(gpu_pool) < needed: + raise RuntimeError( + f"Not enough GPUs for requested workers. Need {needed}, found {len(gpu_pool)} in pool {gpu_pool}." + ) + + for i in range(args.num_workers): + preferred_worker_port = ( + args.worker_base_port + i * args.worker_port_stride + ) + worker_port = _reserve_available_port( + args.worker_host, preferred_worker_port, used_ports + ) + + preferred_master = ( + args.worker_master_port_base + i * args.worker_internal_port_stride + ) + preferred_scheduler = ( + args.worker_scheduler_port_base + + i * args.worker_internal_port_stride + ) + master_port = _reserve_available_port( + "127.0.0.1", preferred_master, used_ports + ) + scheduler_port = _reserve_available_port( + "127.0.0.1", preferred_scheduler, used_ports + ) + + cmd = [ + "sglang", + "serve", + "--model-path", + args.model, + "--num-gpus", + str(args.num_gpus_per_worker), + "--host", + args.worker_host, + "--port", + str(worker_port), + "--master-port", + str(master_port), + "--scheduler-port", + str(scheduler_port), + ] + if args.worker_extra_args: + cmd += shlex.split(args.worker_extra_args) + + env = dict(worker_env) + if gpu_pool: + start = i * args.num_gpus_per_worker + stop = start + args.num_gpus_per_worker + env["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_pool[start:stop]) + + worker_url = f"http://{host_for_url}:{worker_port}" + print(f"[run] {' '.join(shlex.quote(x) for x in cmd)}", flush=True) + proc = subprocess.Popen( + cmd, + cwd=repo_root, + env=env, + start_new_session=True, + ) + managed_processes.append(proc) + launched_workers.append((worker_url, proc)) + worker_urls.append(worker_url) + + if not worker_urls: + raise RuntimeError( + "No workers available. Use --worker-urls or launch workers by removing --skip-workers." 
+ ) + + for url, proc in launched_workers: + _wait_for_health(url, args.wait_timeout, f"worker {url}", proc=proc) + for url in worker_urls: + if url not in {u for u, _ in launched_workers}: + _wait_for_health(url, args.wait_timeout, f"external worker {url}") + + router_env = _build_pythonpath_with_src(repo_root, os.environ.copy()) + router_cmd = [ + sys.executable, + "-m", + "sglang_diffusion_routing", + "--host", + args.router_host, + "--port", + str(args.router_port), + "--worker-urls", + *worker_urls, + "--routing-algorithm", + args.routing_algorithm, + "--max-connections", + str(args.router_max_connections), + "--timeout", + str(args.router_timeout), + "--health-check-interval", + str(args.router_health_check_interval), + "--health-check-failure-threshold", + str(args.router_health_check_failure_threshold), + "--log-level", + args.log_level.lower(), + ] + if args.router_verbose: + router_cmd.append("--verbose") + if args.router_extra_args: + router_cmd += shlex.split(args.router_extra_args) + + print(f"[run] {' '.join(shlex.quote(x) for x in router_cmd)}", flush=True) + router_proc = subprocess.Popen( + router_cmd, + cwd=repo_root, + env=router_env, + start_new_session=True, + ) + managed_processes.append(router_proc) + + router_client_host = _infer_client_host(args.router_host) + router_url = f"http://{router_client_host}:{args.router_port}" + _wait_for_health( + router_url, args.wait_timeout, f"router {router_url}", proc=router_proc + ) + + bench_cmd = _build_bench_command(args, router_url) + print(f"[run] {' '.join(shlex.quote(x) for x in bench_cmd)}", flush=True) + rc = subprocess.call(bench_cmd, cwd=repo_root) + + if rc == 0 and args.output_file: + out = Path(args.output_file) + if out.exists(): + try: + data = json.loads(out.read_text()) + print( + "[bench] result summary:" + f" throughput_qps={data.get('throughput_qps')}," + f" latency_mean={data.get('latency_mean')}," + f" latency_p99={data.get('latency_p99')}", + flush=True, + ) + except Exception: + 
pass + return rc + finally: + if managed_processes: + _terminate_all(managed_processes) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/benchmarks/diffusion_router/bench_routing_algorithms.py b/tests/benchmarks/diffusion_router/bench_routing_algorithms.py new file mode 100644 index 0000000..ee059ae --- /dev/null +++ b/tests/benchmarks/diffusion_router/bench_routing_algorithms.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 +""" +Compare routing algorithms by running bench_router.py for each one and +collecting results in JSON and CSV outputs. + +Example: + python tests/benchmarks/diffusion_router/bench_routing_algorithms.py \ + --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --num-workers 2 \ + --num-prompts 10 \ + --max-concurrency 2 +""" + +from __future__ import annotations + +import argparse +import csv +import json +import shlex +import subprocess +import sys +from datetime import datetime +from pathlib import Path + +ALL_ALGORITHMS = ["least-request", "round-robin", "random"] +BASELINE = "random" + + +def _require_non_empty_model(model: str) -> str: + normalized = model.strip() + if not normalized: + raise ValueError( + "--model must be a non-empty model ID/path. " + "Detected an empty value, which often means a shell variable such as " + "$MODEL was unset." + ) + return normalized + + +def _pct_delta(value: float | int | str, baseline: float | int | str) -> float | str: + if not isinstance(value, (int, float)): + return "" + if not isinstance(baseline, (int, float)): + return "" + if baseline == 0: + return "" + return ((value - baseline) / abs(baseline)) * 100.0 + + +def _format_num(value: object, width: int = 10) -> str: + if isinstance(value, (int, float)): + return ( + f"{value:>{width}.4f}" if isinstance(value, float) else f"{value:>{width}d}" + ) + return f"{str(value):>{width}}" + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Compare routing algorithms by running bench_router.py for each." 
+ ) + parser.add_argument( + "--model", type=str, required=True, help="Diffusion model HF ID or local path." + ) + parser.add_argument( + "--algorithms", + nargs="+", + default=ALL_ALGORITHMS, + choices=ALL_ALGORITHMS, + help="Algorithms to compare (default: all).", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory to store result artifacts.", + ) + + # Pass-through arguments for bench_router.py + parser.add_argument("--router-host", type=str, default="127.0.0.1") + parser.add_argument("--router-port", type=int, default=30080) + parser.add_argument("--router-verbose", action="store_true") + parser.add_argument("--router-max-connections", type=int, default=100) + parser.add_argument("--router-timeout", type=float, default=120.0) + parser.add_argument("--router-health-check-interval", type=int, default=10) + parser.add_argument("--router-health-check-failure-threshold", type=int, default=3) + parser.add_argument("--router-extra-args", type=str, default="") + parser.add_argument("--worker-host", type=str, default="127.0.0.1") + parser.add_argument("--worker-urls", nargs="*", default=[]) + parser.add_argument("--num-workers", type=int, default=1) + parser.add_argument("--worker-base-port", type=int, default=10090) + parser.add_argument("--worker-port-stride", type=int, default=2) + parser.add_argument("--worker-master-port-base", type=int, default=30005) + parser.add_argument("--worker-scheduler-port-base", type=int, default=5555) + parser.add_argument("--worker-internal-port-stride", type=int, default=1000) + parser.add_argument("--num-gpus-per-worker", type=int, default=1) + parser.add_argument("--worker-gpu-ids", nargs="*", default=None) + parser.add_argument("--worker-extra-args", type=str, default="") + parser.add_argument("--skip-workers", action="store_true") + parser.add_argument( + "--dataset", type=str, default="random", choices=["vbench", "random"] + ) + parser.add_argument("--dataset-path", type=str, default=None) + 
parser.add_argument("--num-prompts", type=int, default=20) + parser.add_argument("--max-concurrency", type=int, default=1) + parser.add_argument("--request-rate", type=float, default=float("inf")) + parser.add_argument("--task", type=str, default=None) + parser.add_argument("--width", type=int, default=None) + parser.add_argument("--height", type=int, default=None) + parser.add_argument("--num-frames", type=int, default=None) + parser.add_argument("--fps", type=int, default=None) + parser.add_argument("--disable-tqdm", action="store_true") + parser.add_argument("--log-level", type=str, default="INFO") + parser.add_argument("--bench-extra-args", type=str, default="") + parser.add_argument("--wait-timeout", type=int, default=1200) + + args = parser.parse_args() + args.model = _require_non_empty_model(args.model) + + script_dir = Path(__file__).resolve().parent + bench_router_script = script_dir / "bench_router.py" + if not bench_router_script.exists(): + raise RuntimeError(f"Missing benchmark script: {bench_router_script}") + + output_dir = ( + Path(args.output_dir) + if args.output_dir + else script_dir + / "outputs" + / f"routing_algo_compare_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + ) + output_dir.mkdir(parents=True, exist_ok=True) + + results: dict[str, dict] = {} + py = sys.executable + + for algo in args.algorithms: + print(f"\n{'=' * 72}", flush=True) + print(f"[bench] Running routing algorithm: {algo}", flush=True) + print(f"{'=' * 72}\n", flush=True) + + out_file = output_dir / f"bench_{algo}.json" + cmd = [ + py, + str(bench_router_script), + "--model", + args.model, + "--routing-algorithm", + algo, + "--router-host", + args.router_host, + "--router-port", + str(args.router_port), + "--router-max-connections", + str(args.router_max_connections), + "--router-timeout", + str(args.router_timeout), + "--router-health-check-interval", + str(args.router_health_check_interval), + "--router-health-check-failure-threshold", + 
str(args.router_health_check_failure_threshold), + "--worker-host", + args.worker_host, + "--num-workers", + str(args.num_workers), + "--worker-base-port", + str(args.worker_base_port), + "--worker-port-stride", + str(args.worker_port_stride), + "--worker-master-port-base", + str(args.worker_master_port_base), + "--worker-scheduler-port-base", + str(args.worker_scheduler_port_base), + "--worker-internal-port-stride", + str(args.worker_internal_port_stride), + "--num-gpus-per-worker", + str(args.num_gpus_per_worker), + "--dataset", + args.dataset, + "--num-prompts", + str(args.num_prompts), + "--max-concurrency", + str(args.max_concurrency), + "--request-rate", + str(args.request_rate), + "--wait-timeout", + str(args.wait_timeout), + "--log-level", + args.log_level, + "--output-file", + str(out_file), + ] + + if args.worker_urls: + cmd += ["--worker-urls", *args.worker_urls] + if args.worker_gpu_ids: + cmd += ["--worker-gpu-ids", *args.worker_gpu_ids] + if args.dataset_path: + cmd += ["--dataset-path", args.dataset_path] + if args.task: + cmd += ["--task", args.task] + if args.width: + cmd += ["--width", str(args.width)] + if args.height: + cmd += ["--height", str(args.height)] + if args.num_frames: + cmd += ["--num-frames", str(args.num_frames)] + if args.fps: + cmd += ["--fps", str(args.fps)] + if args.disable_tqdm: + cmd.append("--disable-tqdm") + if args.router_verbose: + cmd.append("--router-verbose") + if args.skip_workers: + cmd.append("--skip-workers") + if args.worker_extra_args: + cmd += ["--worker-extra-args", args.worker_extra_args] + if args.router_extra_args: + cmd += ["--router-extra-args", args.router_extra_args] + if args.bench_extra_args: + cmd += ["--bench-extra-args", args.bench_extra_args] + + print("[run]", " ".join(shlex.quote(x) for x in cmd), flush=True) + rc = subprocess.call(cmd) + if rc != 0: + print( + f"[warn] benchmark exited with code {rc} for algorithm '{algo}'", + flush=True, + ) + results[algo] = {"error": f"exit_code={rc}"} + 
continue + + if not out_file.exists(): + print(f"[warn] output file missing: {out_file}", flush=True) + results[algo] = {"error": "output_file_missing"} + continue + + try: + results[algo] = json.loads(out_file.read_text()) + except json.JSONDecodeError as exc: + print(f"[warn] invalid JSON in {out_file}: {exc}", flush=True) + results[algo] = {"error": f"json_parse_error={exc}"} + + metric_keys = ["throughput_qps", "latency_mean", "latency_median", "latency_p99"] + baseline = results.get(BASELINE, {}) + rows: list[dict[str, object]] = [] + + for algo in args.algorithms: + data = results.get(algo, {}) + if "error" in data: + row = { + "algorithm": algo, + "throughput_qps": "", + "latency_mean": "", + "latency_median": "", + "latency_p99": "", + "duration": "", + "completed_requests": "", + "failed_requests": "", + "throughput_qps_delta_pct": "", + "latency_mean_delta_pct": "", + "latency_median_delta_pct": "", + "latency_p99_delta_pct": "", + "error": data["error"], + } + rows.append(row) + continue + + row = { + "algorithm": algo, + "throughput_qps": data.get("throughput_qps", ""), + "latency_mean": data.get("latency_mean", ""), + "latency_median": data.get("latency_median", ""), + "latency_p99": data.get("latency_p99", ""), + "duration": data.get("duration", ""), + "completed_requests": data.get("completed_requests", ""), + "failed_requests": data.get("failed_requests", ""), + "error": "", + } + for key in metric_keys: + row[f"{key}_delta_pct"] = _pct_delta( + row.get(key, ""), baseline.get(key, "") + ) + rows.append(row) + + print(f"\n{'=' * 108}", flush=True) + print(f"[summary] Routing Algorithm Comparison (baseline: {BASELINE})", flush=True) + print(f"{'=' * 108}", flush=True) + header = ( + f"{'Algorithm':<16} {'Throughput':>12} {'TputDelta%':>12} " + f"{'MeanLat':>12} {'MeanDelta%':>12} " + f"{'P99Lat':>12} {'P99Delta%':>12} " + f"{'Done':>8} {'Fail':>8}" + ) + print(header, flush=True) + print("-" * len(header), flush=True) + for row in rows: + if 
row.get("error"): + print( + f"{row['algorithm']:<16} {'-':>12} {'-':>12} {'-':>12} {'-':>12} " + f"{'-':>12} {'-':>12} {'-':>8} {'-':>8} error={row['error']}", + flush=True, + ) + continue + print( + f"{row['algorithm']:<16} " + f"{_format_num(row['throughput_qps'], 12)} {_format_num(row['throughput_qps_delta_pct'], 12)} " + f"{_format_num(row['latency_mean'], 12)} {_format_num(row['latency_mean_delta_pct'], 12)} " + f"{_format_num(row['latency_p99'], 12)} {_format_num(row['latency_p99_delta_pct'], 12)} " + f"{_format_num(row['completed_requests'], 8)} {_format_num(row['failed_requests'], 8)}", + flush=True, + ) + + csv_path = output_dir / "routing_algorithm_comparison.csv" + fieldnames = list(rows[0].keys()) if rows else ["algorithm"] + with csv_path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + summary_path = output_dir / "routing_algorithm_comparison.json" + summary_path.write_text( + json.dumps({"results": results, "rows": rows, "baseline": BASELINE}, indent=2) + ) + + print("\n[done] Artifacts:", flush=True) + print(f" - {csv_path}", flush=True) + print(f" - {summary_path}", flush=True) + + # Return non-zero if every algorithm failed. 
+ failures = sum(1 for row in rows if row.get("error")) + return 1 if rows and failures == len(rows) else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py new file mode 100644 index 0000000..1ef2a85 --- /dev/null +++ b/tests/unit/test_cli.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from unittest import mock + +from sglang_diffusion_routing.cli.main import build_parser, run_cli + + +class TestCLIParser: + def test_defaults(self): + args = build_parser().parse_args([]) + assert args.host == "0.0.0.0" + assert args.port == 30080 + assert args.worker_urls == [] + assert args.routing_algorithm == "least-request" + assert args.timeout == 120.0 + assert args.max_connections == 100 + assert args.health_check_interval == 10 + assert args.health_check_failure_threshold == 3 + assert args.verbose is False + assert args.log_level == "info" + + def test_parses_worker_urls(self): + args = build_parser().parse_args( + [ + "--host", + "127.0.0.1", + "--port", + "31000", + "--worker-urls", + "http://localhost:10090", + "http://localhost:10092", + "--routing-algorithm", + "round-robin", + "--verbose", + "--log-level", + "warning", + ] + ) + assert args.host == "127.0.0.1" + assert args.port == 31000 + assert args.worker_urls == ["http://localhost:10090", "http://localhost:10092"] + assert args.routing_algorithm == "round-robin" + assert args.verbose is True + assert args.log_level == "warning" + + +def test_run_cli_calls_router_runner(): + with mock.patch("sglang_diffusion_routing.cli.main._run_router_server") as mock_run: + code = run_cli(["--port", "30123", "--worker-urls", "http://localhost:10090"]) + assert code == 0 + mock_run.assert_called_once() + args = mock_run.call_args.args[0] + assert args.port == 30123 + assert args.worker_urls == ["http://localhost:10090"] + assert mock_run.call_args.kwargs["worker_urls"] == ["http://localhost:10090"] + assert mock_run.call_args.kwargs["log_prefix"] 
== "[sglang-d-router]" diff --git a/tests/unit/test_diffusion_router.py b/tests/unit/test_diffusion_router.py new file mode 100644 index 0000000..223835d --- /dev/null +++ b/tests/unit/test_diffusion_router.py @@ -0,0 +1,229 @@ +import asyncio +import json +from argparse import Namespace +from types import SimpleNamespace + +import pytest + +from sglang_diffusion_routing import DiffusionRouter + + +def make_router_args(**overrides) -> Namespace: + """Create a Namespace with default DiffusionRouter args, applying overrides.""" + defaults = dict( + host="127.0.0.1", + port=30080, + max_connections=100, + timeout=120.0, + routing_algorithm="least-request", + ) + defaults.update(overrides) + return Namespace(**defaults) + + +@pytest.fixture +def router_factory(): + """Factory fixture that creates routers and closes their clients at teardown.""" + created_routers: list[DiffusionRouter] = [] + + def _create( + workers: dict[str, int], + dead: set[str] | None = None, + **arg_overrides, + ) -> DiffusionRouter: + router = DiffusionRouter(make_router_args(**arg_overrides)) + router.worker_request_counts = dict(workers) + router.worker_failure_counts = {url: 0 for url in workers} + if dead: + router.dead_workers = set(dead) + created_routers.append(router) + return router + + yield _create + + for router in created_routers: + asyncio.run(router.client.aclose()) + + +class TestLeastRequest: + """Test the least-request (default) load-balancing algorithm.""" + + def test_selects_min_load(self, router_factory): + router = router_factory( + {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8} + ) + selected = router._use_url() + assert selected == "http://w2:8000" + assert router.worker_request_counts["http://w2:8000"] == 3 + + def test_excludes_dead_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8}, + dead={"http://w2:8000"}, + ) + selected = router._use_url() + assert selected == 
"http://w1:8000" + assert router.worker_request_counts["http://w1:8000"] == 6 + + +class TestRoundRobin: + """Test the round-robin load-balancing algorithm.""" + + def test_cycles_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, + routing_algorithm="round-robin", + ) + results = [router._use_url() for _ in range(6)] + workers = list(router.worker_request_counts.keys()) + expected = [workers[i % 3] for i in range(6)] + assert results == expected + for url in workers: + assert router.worker_request_counts[url] == 2 + + def test_excludes_dead_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, + dead={"http://w2:8000"}, + routing_algorithm="round-robin", + ) + results = [router._use_url() for _ in range(4)] + assert "http://w2:8000" not in results + assert all(url in ("http://w1:8000", "http://w3:8000") for url in results) + + +class TestRandom: + """Test the random load-balancing algorithm.""" + + def test_selects_from_valid_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, + routing_algorithm="random", + ) + seen = set() + for _ in range(30): + # Reset counts so they do not grow unbounded + for url in router.worker_request_counts: + router.worker_request_counts[url] = 0 + seen.add(router._use_url()) + assert seen == {"http://w1:8000", "http://w2:8000", "http://w3:8000"} + + def test_excludes_dead_workers(self, router_factory): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0, "http://w3:8000": 0}, + dead={"http://w2:8000"}, + routing_algorithm="random", + ) + for _ in range(20): + url = router._use_url() + assert url != "http://w2:8000" + router.worker_request_counts[url] -= 1 # reset increment + + +class TestErrorCases: + """Test error handling across all routing algorithms.""" + + @pytest.mark.parametrize("algorithm", 
["least-request", "round-robin", "random"]) + def test_raises_when_no_workers(self, router_factory, algorithm): + router = router_factory({}, routing_algorithm=algorithm) + with pytest.raises(RuntimeError, match="No workers registered"): + router._use_url() + + @pytest.mark.parametrize("algorithm", ["least-request", "round-robin", "random"]) + def test_raises_when_all_dead(self, router_factory, algorithm): + router = router_factory( + {"http://w1:8000": 0, "http://w2:8000": 0}, + dead={"http://w1:8000", "http://w2:8000"}, + routing_algorithm=algorithm, + ) + with pytest.raises(RuntimeError, match="No healthy workers"): + router._use_url() + + +class TestCountManagement: + """Test that _use_url / _finish_url correctly track active request counts.""" + + @pytest.mark.parametrize("algorithm", ["least-request", "round-robin", "random"]) + def test_increment_and_finish(self, router_factory, algorithm): + router = router_factory({"http://w1:8000": 0}, routing_algorithm=algorithm) + url = router._use_url() + assert router.worker_request_counts[url] == 1 + router._finish_url(url) + assert router.worker_request_counts[url] == 0 + + +class TestDefaults: + """Test default routing algorithm when the attribute is absent.""" + + def test_default_algorithm_is_least_request(self): + args = Namespace( + host="127.0.0.1", port=30080, max_connections=100, timeout=120.0 + ) + # args has no routing_algorithm attribute + router = DiffusionRouter(args) + try: + assert router.routing_algorithm == "least-request" + finally: + asyncio.run(router.client.aclose()) + + +class TestRegressions: + def test_forward_body_error_does_not_leak_request_count(self, router_factory): + router = router_factory({"http://w1:8000": 0}) + + class BrokenRequest: + method = "POST" + headers = {"content-type": "application/json"} + url = SimpleNamespace(query="") + + async def body(self): + raise RuntimeError("body read failed") + + response = asyncio.run( + router._forward_to_worker(BrokenRequest(), 
"v1/images/generations") + ) + assert response.status_code == 502 + assert router.worker_request_counts["http://w1:8000"] == 0 + + def test_register_worker_normalizes_duplicate_urls(self, router_factory): + router = router_factory({}) + router.register_worker("http://LOCALHOST:10090/") + router.register_worker("http://localhost:10090") + assert list(router.worker_request_counts.keys()) == ["http://localhost:10090"] + + def test_register_worker_rejects_metadata_host(self, router_factory): + router = router_factory({}) + with pytest.raises(ValueError, match="host is blocked"): + router.register_worker("http://169.254.169.254:80") + + def test_broadcast_to_workers_collects_per_worker_results(self, router_factory): + router = router_factory({"http://w1:8000": 0, "http://w2:8000": 0}) + + class FakeResponse: + def __init__(self, status_code: int, body: dict): + self.status_code = status_code + self._body = body + + async def aread(self) -> bytes: + return json.dumps(self._body).encode("utf-8") + + responses = { + "http://w1:8000/update_weights_from_disk": FakeResponse(200, {"ok": True}), + "http://w2:8000/update_weights_from_disk": FakeResponse( + 500, {"error": "bad worker"} + ), + } + + async def fake_post(url, content, headers): + del content, headers + return responses[url] + + router.client.post = fake_post # type: ignore[assignment] + result = asyncio.run( + router._broadcast_to_workers("update_weights_from_disk", b"{}", {}) + ) + assert len(result) == 2 + assert {item["worker_url"] for item in result} == { + "http://w1:8000", + "http://w2:8000", + } diff --git a/tests/unit/test_router_endpoints.py b/tests/unit/test_router_endpoints.py new file mode 100644 index 0000000..2983b8d --- /dev/null +++ b/tests/unit/test_router_endpoints.py @@ -0,0 +1,68 @@ +from argparse import Namespace + +from fastapi.testclient import TestClient + +from sglang_diffusion_routing import DiffusionRouter + + +def make_router_args(**overrides) -> Namespace: + defaults = dict( + 
host="127.0.0.1", + port=30080, + max_connections=100, + timeout=120.0, + routing_algorithm="least-request", + health_check_interval=3600, + health_check_failure_threshold=3, + ) + defaults.update(overrides) + return Namespace(**defaults) + + +def test_add_worker_normalizes_and_deduplicates(): + router = DiffusionRouter(make_router_args()) + with TestClient(router.app) as client: + first = client.post("/add_worker", params={"url": "http://LOCALHOST:10090/"}) + assert first.status_code == 200 + + second = client.post("/add_worker", params={"url": "http://localhost:10090"}) + assert second.status_code == 200 + payload = second.json() + assert payload["worker_urls"] == ["http://localhost:10090"] + + listed = client.get("/list_workers") + assert listed.status_code == 200 + assert listed.json()["urls"] == ["http://localhost:10090"] + + +def test_add_worker_rejects_blocked_metadata_host(): + router = DiffusionRouter(make_router_args()) + with TestClient(router.app) as client: + response = client.post( + "/add_worker", params={"url": "http://169.254.169.254:80"} + ) + assert response.status_code == 400 + assert "blocked" in response.json()["error"] + + +def test_update_weights_from_disk_returns_broadcast_results(): + router = DiffusionRouter(make_router_args()) + router.register_worker("http://localhost:10090") + + async def fake_broadcast(path: str, body: bytes, headers: dict): + assert path == "update_weights_from_disk" + assert body == b'{"model_path":"abc"}' + assert headers.get("content-type", "").startswith("application/json") + return [ + { + "worker_url": "http://localhost:10090", + "status_code": 200, + "body": {"ok": True}, + } + ] + + router._broadcast_to_workers = fake_broadcast # type: ignore[assignment] + with TestClient(router.app) as client: + response = client.post("/update_weights_from_disk", json={"model_path": "abc"}) + assert response.status_code == 200 + assert response.json()["results"][0]["status_code"] == 200 From 
f02fc8a218703ea3c7bf6912f68c488eee941a86 Mon Sep 17 00:00:00 2001 From: alphabetc1 <2508695655@qq.com> Date: Thu, 19 Feb 2026 10:14:57 +0800 Subject: [PATCH 2/2] feat: support lint ci --- .github/workflows/lint.yml | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..eb85ecd --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,24 @@ +name: Lint + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.12" + + - name: Install pre-commit + run: python -m pip install pre-commit + + - name: Run pre-commit checks + run: SKIP=no-commit-to-branch pre-commit run --all-files --show-diff-on-failure