Skip to content
Prev Previous commit
Next Next commit
remove __main__ block
  • Loading branch information
sniper35 committed Feb 8, 2026
commit 6111a94710415dc573da6a375a4aca50925a562a
6 changes: 1 addition & 5 deletions examples/diffusion_router/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,7 @@ def main():

# Pre-register any workers specified on the command line
for url in args.worker_urls:
if url not in router.worker_request_counts:
router.worker_request_counts[url] = 0
router.worker_failure_counts[url] = 0
if args.verbose:
print(f"[demo] Pre-registered worker: {url}")
router.register_worker(url)

print(f"[demo] Starting diffusion router on {args.host}:{args.port}")
print(f"[demo] Workers: {list(router.worker_request_counts.keys()) or '(none — add via POST /add_worker)'}")
Expand Down
47 changes: 10 additions & 37 deletions miles/router/diffusion_router.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
import argparse
import asyncio
import json
import logging
import random

import httpx
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.responses import Response

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


def run_diffusion_router(args):
    """Build a DiffusionRouter from ``args`` and serve its app with uvicorn.

    Args:
        args: Configuration namespace. ``args.host`` and ``args.port`` select
            the bind address; the whole namespace is passed to
            ``DiffusionRouter``. An optional ``args.verbose`` attribute is
            forwarded to the router as its ``verbose`` flag.
    """
    # The script entry point passes verbose=args.verbose, but this helper used
    # to drop the flag. getattr keeps namespaces without ``verbose`` working.
    router = DiffusionRouter(args, verbose=getattr(args, "verbose", False))
    uvicorn.run(router.app, host=args.host, port=args.port, log_level="info")


class DiffusionRouter:
def __init__(self, args, verbose=False):
"""Initialize the diffusion router for load-balancing across sglang-diffusion workers."""
Expand Down Expand Up @@ -264,8 +256,16 @@ async def update_weights_from_disk(self, request: Request):
results = await self._broadcast_to_workers("update_weights_from_disk", body, headers)
return JSONResponse(content={"results": results})

def register_worker(self, url: str) -> None:
    """Register a worker URL if not already known (sync, for startup use).

    A new URL gets a request counter and a failure counter, both starting at
    zero. Registering an already-known URL is a no-op, so existing counters
    are never reset.

    Args:
        url: Worker URL to register. Must be non-empty.

    Raises:
        ValueError: If ``url`` is empty or ``None``. An empty key would
            otherwise be stored as a worker that cannot be reached.
    """
    if not url:
        raise ValueError("worker url must be a non-empty string")

    if url in self.worker_request_counts:
        return

    self.worker_request_counts[url] = 0
    self.worker_failure_counts[url] = 0
    if self.verbose:
        print(f"[diffusion-router] Added new worker: {url}")

async def add_worker(self, request: Request):
"""Register a new diffusion worker."""
"""Register a new diffusion worker (HTTP endpoint)."""
worker_url = request.query_params.get("url") or request.query_params.get("worker_url")

if not worker_url:
Expand All @@ -281,12 +281,7 @@ async def add_worker(self, request: Request):
status_code=400, content={"error": "worker_url is required (use query ?url=... or JSON body)"}
)

if worker_url not in self.worker_request_counts:
self.worker_request_counts[worker_url] = 0
self.worker_failure_counts[worker_url] = 0
if self.verbose:
print(f"[diffusion-router] Added new worker: {worker_url}")

self.register_worker(worker_url)
return {"status": "success", "worker_urls": list(self.worker_request_counts.keys())}

async def list_workers(self, request: Request):
Expand All @@ -296,25 +291,3 @@ async def list_workers(self, request: Request):
async def proxy(self, request: Request, path: str):
    """Catch-all route handler for requests no other endpoint matched.

    The incoming request and the captured ``path`` are handed unchanged to
    ``self._forward_to_worker``, and whatever it produces is returned.
    """
    response = await self._forward_to_worker(request, path)
    return response


if __name__ == "__main__":
    # Script entry point: parse CLI options, build the router, pre-register
    # any workers given on the command line, then serve the app with uvicorn.
    parser = argparse.ArgumentParser(description="Miles Diffusion Router")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=30080)
    parser.add_argument("--worker-urls", nargs="*", default=[], help="Initial worker URLs to register")
    parser.add_argument("--max-connections", type=int, default=100)
    parser.add_argument("--timeout", type=float, default=None)
    parser.add_argument("--health-check-interval", type=int, default=10)
    parser.add_argument("--health-check-failure-threshold", type=int, default=3)
    parser.add_argument(
        "--routing-algorithm",
        type=str,
        default="least-request",
        choices=["least-request", "round-robin", "random"],
    )
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    router = DiffusionRouter(args, verbose=args.verbose)

    # Register through the router's own method rather than writing its counter
    # dicts directly, so a URL listed twice does not have its counters reset.
    for url in args.worker_urls:
        router.register_worker(url)

    uvicorn.run(router.app, host=args.host, port=args.port, log_level="info")