Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 88 additions & 18 deletions miles/router/router.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse
import asyncio
import json
import logging

import httpx
import uvicorn
Expand All @@ -9,6 +11,8 @@

from miles.utils.misc import load_function

logger = logging.getLogger(__name__)


def run_router(args):
"""
Expand All @@ -28,9 +32,14 @@ def __init__(self, args, verbose=False):
self.verbose = verbose

self.app = FastAPI()

# Worker information
self.worker_urls: dict[str, int] = {}
self.app.add_event_handler("startup", self._start_background_health_check)

# URL -> Active Request Count (load state)
self.worker_request_counts: dict[str, int] = {}
# URL -> Consecutive Failures
self.worker_failure_counts: dict[str, int] = {}
# Quarantined workers excluded from routing pool
self.dead_workers: set[str] = set()
self.max_weight_version = None

max_connections = getattr(args, "miles_router_max_connections", None)
Expand Down Expand Up @@ -63,9 +72,61 @@ def _setup_routes(self):
# Catch-all route for proxying to SGLang - must be registered LAST
self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy)

async def health_check(self, request: Request):
# TODO: do health check in background
pass
async def _start_background_health_check(self):
    """FastAPI startup hook: launch the periodic worker health-check task.

    Keeps a reference to the task on ``self`` so it is not garbage-collected
    mid-flight — the event loop only holds a weak reference to tasks created
    via ``asyncio.create_task``.
    """
    self._health_check_task = asyncio.create_task(self._health_check_loop())

async def _check_worker_health(self, url):
"""Encapsulated health check logic for better maintainability."""
try:
response = await self.client.get(f"{url}/health", timeout=5.0)
if response.status_code == 200:
return url, True
logger.debug(f"[miles-router] Worker {url} is unhealthy (Status: {response.status_code})")
except Exception as e:
logger.debug(f"[miles-router] Worker {url} health check failed: {e}")
return url, False

async def _health_check_loop(self):
    """Background loop that probes workers and quarantines repeat offenders.

    Every ``rollout_health_check_interval`` seconds, probe all workers not
    already quarantined. A worker that fails
    ``miles_router_health_check_failure_threshold`` consecutive probes is
    added to ``self.dead_workers`` and excluded from routing; a single
    successful probe resets its failure streak. Runs until the task is
    cancelled; unexpected errors are logged and the loop keeps going.
    """
    interval = self.args.rollout_health_check_interval
    threshold = self.args.miles_router_health_check_failure_threshold

    while True:
        try:
            await asyncio.sleep(interval)

            # Dead workers are skipped entirely (see TODO below on revival).
            urls = [u for u in self.worker_request_counts if u not in self.dead_workers]
            if not urls:
                continue

            results = await asyncio.gather(*(self._check_worker_health(url) for url in urls))

            for url, is_healthy in results:
                if is_healthy:
                    # One success wipes the failure streak.
                    self.worker_failure_counts[url] = 0
                    continue

                failures = self.worker_failure_counts.get(url, 0) + 1
                self.worker_failure_counts[url] = failures

                if failures >= threshold:
                    logger.warning(
                        "[miles-router] Worker %s failed %s consecutive health checks. Marking as DEAD.",
                        url,
                        threshold,
                    )
                    self.dead_workers.add(url)
                    # TODO (chenyang): Connect back 'dead' workers requires a mechanism to sync
                    # model versions to avoid off-policy issues from stale weights, since these
                    # dead workers' parameters may not be refitted.

            logger.debug(
                "[miles-router] Health check complete. %s workers healthy.",
                len(self.worker_request_counts) - len(self.dead_workers),
            )

        except asyncio.CancelledError:
            logger.warning("[miles-router] Background health check loop is being cancelled.")
            raise
        except Exception:
            # Keep the monitor alive on unexpected errors; logger.exception
            # records the traceback, and a short backoff avoids hot-looping.
            logger.exception("[miles-router] Unexpected error in health check loop")
            await asyncio.sleep(5)

async def proxy(self, request: Request, path: str):
"""Proxy all other requests to the SGLang router"""
Expand Down Expand Up @@ -124,16 +185,17 @@ async def add_worker(self, request: Request):
)

# Add if new, keep a simple request count per worker
if worker_url not in self.worker_urls:
self.worker_urls[worker_url] = 0
if worker_url not in self.worker_request_counts:
self.worker_request_counts[worker_url] = 0
self.worker_failure_counts[worker_url] = 0
if self.verbose:
print(f"[miles-router] Added new worker: {worker_url}")

return {"status": "success", "worker_urls": self.worker_urls}
return {"status": "success", "worker_urls": self.worker_request_counts}

async def list_workers(self, request: Request):
    """Return the URLs of every registered worker (quarantined ones included)."""
    registered = list(self.worker_request_counts)
    return {"urls": registered}

async def retrieve_from_text(self, request: Request):
"""Get token information from text input"""
Expand All @@ -158,19 +220,27 @@ async def retrieve_from_text(self, request: Request):
return result

def _use_url(self):
"""Select a worker URL using round-robin strategy"""
assert len(self.worker_urls) > 0, "No workers available"
"""Select worker URL with minimal active requests."""

if not self.dead_workers:
# Healthy path: select from all workers
url = min(self.worker_request_counts, key=self.worker_request_counts.get)
else:
# Degraded path: select from workers not in dead_workers
valid_workers = (w for w in self.worker_request_counts if w not in self.dead_workers)
try:
url = min(valid_workers, key=self.worker_request_counts.get)
except ValueError:
raise RuntimeError("No healthy workers available in the pool") from None

# get the url with mininal count
url = min(self.worker_urls, key=self.worker_urls.get)
self.worker_urls[url] += 1
self.worker_request_counts[url] += 1
return url

def _finish_url(self, url):
"""Mark the request to the given URL as finished"""
assert url in self.worker_urls, f"URL {url} not recognized"
self.worker_urls[url] -= 1
assert self.worker_urls[url] >= 0, f"URL {url} count went negative"
assert url in self.worker_request_counts, f"URL {url} not recognized"
self.worker_request_counts[url] -= 1
assert self.worker_request_counts[url] >= 0, f"URL {url} count went negative"


if __name__ == "__main__":
Expand Down
6 changes: 6 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,12 @@ def add_router_arguments(parser):
default=None,
help="Max connections for MilesRouter HTTP client.",
)
parser.add_argument(
"--miles-router-health-check-failure-threshold",
type=int,
default=3,
help="Number of consecutive failures before marking a worker as unhealthy.",
)
RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
return parser

Expand Down