|
15 | 15 | import atexit |
16 | 16 | import json |
17 | 17 | from abc import abstractmethod |
| 18 | +from contextlib import asynccontextmanager |
| 19 | +from io import StringIO |
18 | 20 | from logging import Filter as LoggingFilter |
19 | 21 | from logging import LogRecord, getLogger |
20 | 22 | from os import getenv |
| 23 | +from pathlib import Path |
21 | 24 | from threading import Thread |
22 | 25 | from typing import Literal, Optional, Tuple, Type, Union, Unpack |
23 | 26 | from uuid import uuid4 |
24 | 27 |
|
25 | 28 | import requests |
26 | 29 | import uvicorn |
| 30 | +import yappi |
27 | 31 | from aiohttp import ClientResponse, ClientSession, ClientTimeout, DummyCookieJar, ServerDisconnectedError, TCPConnector |
28 | 32 | from aiohttp.client import _RequestOptions |
29 | 33 | from fastapi import FastAPI, Request, Response |
|
32 | 36 | from requests.exceptions import ConnectionError |
33 | 37 | from starlette.middleware.sessions import SessionMiddleware |
34 | 38 |
|
| 39 | +from nemo_gym import PARENT_DIR |
35 | 40 | from nemo_gym.config_types import ( |
36 | 41 | BaseRunServerInstanceConfig, |
37 | 42 | BaseServerConfig, |
|
50 | 55 |
|
51 | 56 |
|
52 | 57 | class GlobalAIOHTTPAsyncClientConfig(BaseModel): |
53 | | - global_aiohttp_connector_limit: int = 1000 |
54 | | - global_aiohttp_connector_limit_per_host: int = 100 |
| 58 | + global_aiohttp_connector_limit: int = 100 * 1024 |
| 59 | + global_aiohttp_connector_limit_per_host: int = 1024 |
55 | 60 |
|
56 | 61 |
|
57 | 62 | def get_global_aiohttp_client( |
@@ -123,7 +128,7 @@ async def request(method: str, url: str, **kwargs: Unpack[_RequestOptions]) -> C |
123 | 128 | await asyncio.sleep(0.5) |
124 | 129 | except Exception as e: |
125 | 130 | print( |
126 | | - f"""Hit an exception while making a request (try {num_tries}): {e} |
| 131 | + f"""Hit an exception while making a request (try {num_tries}): {type(e)}: {e} |
127 | 132 | Sleeping 0.5s and retrying... |
128 | 133 | """ |
129 | 134 | ) |
@@ -274,6 +279,20 @@ def load_config_from_global_config(cls) -> "BaseRunServerInstanceConfig": |
274 | 279 | return server_config |
275 | 280 |
|
276 | 281 |
|
| 282 | +class ProfilingMiddlewareInputConfig(BaseModel): |
| 283 | + # Relative to the Gym root dir. |
| 284 | + profiling_results_dirpath: Optional[str] = None |
| 285 | + |
| 286 | + |
| 287 | +class ProfilingMiddlewareConfig(ProfilingMiddlewareInputConfig): |
| 288 | + profiling_enabled: bool = False |
| 289 | + |
| 290 | + |
| 291 | +class UvicornLoggingConfig(BaseModel): |
| 292 | + # Default to False for regular use cases. |
| 293 | + uvicorn_logging_show_200_ok: bool = False |
| 294 | + |
| 295 | + |
277 | 296 | class SimpleServer(BaseServer): |
278 | 297 | server_client: ServerClient |
279 | 298 |
|
@@ -305,36 +324,86 @@ async def add_session_id(request: Request, call_next): # pragma: no cover |
305 | 324 | session_middleware_key = self.get_session_middleware_key() |
306 | 325 | app.add_middleware(SessionMiddleware, secret_key=session_middleware_key, session_cookie=session_middleware_key) |
307 | 326 |
|
| 327 | + def setup_profiling(self, app: FastAPI, profiling_config: ProfilingMiddlewareConfig) -> None: # pragma: no cover |
| 328 | + base_profile_dir = Path(PARENT_DIR) / profiling_config.profiling_results_dirpath |
| 329 | + server_profile_path = (base_profile_dir / self.get_session_middleware_key()).with_suffix(".log") |
| 330 | + |
| 331 | + base_profile_dir.mkdir(parents=True, exist_ok=True) |
| 332 | + |
| 333 | + main_app_lifespan = app.router.lifespan_context |
| 334 | + |
| 335 | + @asynccontextmanager |
| 336 | + async def lifespan_wrapper(app): |
| 337 | + yappi.set_clock_type("WALL") |
| 338 | + yappi.start() |
| 339 | + print(f"🔍 Enabled profiling for {self.config.name}") |
| 340 | + |
| 341 | + async with main_app_lifespan(app) as maybe_state: |
| 342 | + yield maybe_state |
| 343 | + |
| 344 | + print(f"🛑 Stopping profiler for {self.config.name}. Check {server_profile_path} for the metrics!") |
| 345 | + yappi.stop() |
| 346 | + |
| 347 | + buffer = StringIO() |
| 348 | + yappi.get_func_stats().print_all( |
| 349 | + out=buffer, |
| 350 | + columns={ |
| 351 | + 0: ("name", 200), |
| 352 | + 1: ("ncall", 10), |
| 353 | + 2: ("tsub", 8), |
| 354 | + 3: ("ttot", 8), |
| 355 | + 4: ("tavg", 8), |
| 356 | + }, |
| 357 | + ) |
| 358 | + |
| 359 | + buffer.seek(0) |
| 360 | + with open(server_profile_path, "w") as f: |
| 361 | + past_header = False |
| 362 | + for line in buffer: |
| 363 | + if not past_header or self.config.entrypoint in line: |
| 364 | + f.write(line) |
| 365 | + |
| 366 | + if line.startswith("name"): |
| 367 | + past_header = True |
| 368 | + |
| 369 | + app.router.lifespan_context = lifespan_wrapper |
| 370 | + |
308 | 371 | @classmethod |
309 | 372 | def run_webserver(cls) -> None: # pragma: no cover |
| 373 | + global_config_dict = get_global_config_dict() |
| 374 | + |
310 | 375 | server_config = cls.load_config_from_global_config() |
311 | 376 | server_client = ServerClient( |
312 | 377 | head_server_config=ServerClient.load_head_server_config(), |
313 | | - global_config_dict=get_global_config_dict(), |
| 378 | + global_config_dict=global_config_dict, |
314 | 379 | ) |
315 | 380 | server = cls(config=server_config, server_client=server_client) |
316 | 381 |
|
317 | 382 | app = server.setup_webserver() |
318 | 383 |
|
319 | | - class No200Filter(LoggingFilter): |
320 | | - def filter(self, record: LogRecord) -> bool: |
321 | | - msg = record.getMessage() |
322 | | - return not msg.strip().endswith("200") |
| 384 | + profiling_config = ProfilingMiddlewareConfig.model_validate(global_config_dict) |
| 385 | + if profiling_config.profiling_enabled: |
| 386 | + server.setup_profiling(app, profiling_config) |
323 | 387 |
|
324 | | - uvicorn_logger = getLogger("uvicorn.access") |
325 | | - uvicorn_logger.addFilter(No200Filter()) |
| 388 | + uvicorn_logging_cfg = UvicornLoggingConfig.model_validate(global_config_dict) |
| 389 | + if not uvicorn_logging_cfg.uvicorn_logging_show_200_ok: |
326 | 390 |
|
327 | | - print( |
328 | | - "Adding a uvicorn logging filter so that the logs aren't spammed with 200 OK messages. This is to help errors pop up better and filter out noise." |
329 | | - ) |
| 391 | + class No200Filter(LoggingFilter): |
| 392 | + def filter(self, record: LogRecord) -> bool: |
| 393 | + msg = record.getMessage() |
| 394 | + return not msg.strip().endswith("200") |
| 395 | + |
| 396 | + uvicorn_logger = getLogger("uvicorn.access") |
| 397 | + uvicorn_logger.addFilter(No200Filter()) |
| 398 | + |
| 399 | + print( |
| 400 | + "Adding a uvicorn logging filter so that the logs aren't spammed with 200 OK messages. This is to help errors pop up better and filter out noise." |
| 401 | + ) |
330 | 402 |
|
331 | 403 | uvicorn.run( |
332 | 404 | app, |
333 | 405 | host=server.config.host, |
334 | 406 | port=server.config.port, |
335 | | - # We don't have any explicit lifespan logic, so instead of defaulting to "auto" |
336 | | - # We just turn lifespan off |
337 | | - lifespan="off", |
338 | 407 | ) |
339 | 408 |
|
340 | 409 |
|
|
0 commit comments