diff --git a/examples/diffusion_router/README.md b/examples/diffusion_router/README.md
index c4af0e7b5..55c38c4d0 100644
--- a/examples/diffusion_router/README.md
+++ b/examples/diffusion_router/README.md
@@ -19,7 +19,7 @@ curl -X POST 'http://localhost:30080/add_worker?url=http://localhost:10090'
 | Method | Path | Description |
 |--------|------|-------------|
 | `POST` | `/generate` | Image generation (forwards to `/v1/images/generations`) |
-| `POST` | `/generate_video` | Video generation (forwards to `/v1/videos/generations`) |
+| `POST` | `/generate_video` | Video generation (forwards to `/v1/videos`) |
 | `GET` | `/health` | Aggregated router health |
 | `GET` | `/health_workers` | Per-worker health and load info |
 | `POST` | `/add_worker` | Register a diffusion worker (`?url=...` or JSON body) |
diff --git a/examples/diffusion_router/bench_router.py b/examples/diffusion_router/bench_router.py
index a20d64ac6..70f78783c 100644
--- a/examples/diffusion_router/bench_router.py
+++ b/examples/diffusion_router/bench_router.py
@@ -446,7 +446,9 @@ def main() -> int:
         if args.bench_extra_args:
             bench_cmd += shlex.split(args.bench_extra_args)
 
-        return subprocess.call(bench_cmd, env=env)
+        bench_proc = subprocess.Popen(bench_cmd, env=env, start_new_session=True)
+        processes.append(bench_proc)
+        return bench_proc.wait()
     finally:
         _terminate_all(reversed(processes))
 
diff --git a/miles/router/diffusion_router.py b/miles/router/diffusion_router.py
index d77671083..facc9c51d 100644
--- a/miles/router/diffusion_router.py
+++ b/miles/router/diffusion_router.py
@@ -137,9 +137,11 @@ def _use_url(self):
 
     def _finish_url(self, url):
         """Mark the request to the given URL as finished."""
-        assert url in self.worker_request_counts, f"URL {url} not recognized"
+        if url not in self.worker_request_counts:
+            raise ValueError(f"URL {url} not recognized")
         self.worker_request_counts[url] -= 1
-        assert self.worker_request_counts[url] >= 0, f"URL {url} count went negative"
+        if self.worker_request_counts[url] < 0:
+            raise RuntimeError(f"URL {url} count went negative")
 
     # ── Proxy helpers ────────────────────────────────────────────────
 
@@ -187,7 +189,11 @@ async def _forward_to_worker(self, request: Request, path: str) -> Response:
         try:
             response = await self.client.request(request.method, url, content=body, headers=headers)
             content = await response.aread()
+        except Exception as exc:
+            logger.error(f"[diffusion-router] Failed to forward request to {worker_url}: {exc}")
+            return JSONResponse(status_code=502, content={"error": f"Worker request failed: {exc}"})
         finally:
+            # Always release the in-flight slot: asyncio.CancelledError is a BaseException and bypasses `except Exception`.
             self._finish_url(worker_url)
 
         resp_headers = self._sanitize_response_headers(response.headers)
diff --git a/tests/fast/router/test_diffusion_router.py b/tests/fast/router/test_diffusion_router.py
index 65a2df394..0d2d2a134 100644
--- a/tests/fast/router/test_diffusion_router.py
+++ b/tests/fast/router/test_diffusion_router.py
@@ -50,6 +50,15 @@ def test_selects_min_load(self, router_factory):
         assert selected == "http://w2:8000"
         assert router.worker_request_counts["http://w2:8000"] == 3
 
+    def test_excludes_dead_workers(self, router_factory):
+        router = router_factory(
+            {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8},
+            dead={"http://w2:8000"},
+        )
+        selected = router._use_url()
+        assert selected == "http://w1:8000"
+        assert router.worker_request_counts["http://w1:8000"] == 6
+
 
 # ── Round-robin ──────────────────────────────────────────────────