 import asyncio
 import atexit
 import json
+import resource
 from abc import abstractmethod
 from contextlib import asynccontextmanager
 from io import StringIO
 from os import getenv
 from pathlib import Path
 from threading import Thread
+from traceback import print_exc
 from typing import Literal, Optional, Tuple, Type, Union, Unpack
 from uuid import uuid4

 from aiohttp import ClientResponse, ClientSession, ClientTimeout, DummyCookieJar, ServerDisconnectedError, TCPConnector
 from aiohttp.client import _RequestOptions
 from fastapi import FastAPI, Request, Response
+from fastapi.responses import JSONResponse
 from omegaconf import DictConfig, OmegaConf
 from pydantic import BaseModel, ConfigDict
 from requests.exceptions import ConnectionError
@@ -62,7 +65,7 @@ class GlobalAIOHTTPAsyncClientConfig(BaseModel):
 def get_global_aiohttp_client(
     global_config_dict_parser_config: Optional[GlobalConfigDictParserConfig] = None,
     global_config_dict_parser_cls: Type[GlobalConfigDictParser] = GlobalConfigDictParser,
-) -> ClientSession:
+) -> ClientSession:  # pragma: no cover
     global _GLOBAL_AIOHTTP_CLIENT

     if _GLOBAL_AIOHTTP_CLIENT is not None:
@@ -77,7 +80,7 @@ def get_global_aiohttp_client(
     return set_global_aiohttp_client(cfg)


-def set_global_aiohttp_client(cfg: GlobalAIOHTTPAsyncClientConfig) -> ClientSession:
+def set_global_aiohttp_client(cfg: GlobalAIOHTTPAsyncClientConfig) -> ClientSession:  # pragma: no cover
     assert not is_global_aiohttp_client_setup(), (
         "There is already a global aiohttp client set up. Please refactor your code, or call `global_aiohttp_client_exit` if you want to explicitly re-create the client!"
     )
@@ -97,11 +100,11 @@ def set_global_aiohttp_client(cfg: GlobalAIOHTTPAsyncClientConfig) -> ClientSession:
     return _GLOBAL_AIOHTTP_CLIENT


-def is_global_aiohttp_client_setup() -> bool:
+def is_global_aiohttp_client_setup() -> bool:  # pragma: no cover
     return _GLOBAL_AIOHTTP_CLIENT is not None


-def global_aiohttp_client_exit():
+def global_aiohttp_client_exit():  # pragma: no cover
     if not is_global_aiohttp_client_setup():
         return

@@ -118,7 +121,9 @@ def global_aiohttp_client_exit():
 MAX_NUM_TRIES = 3


-async def request(method: str, url: str, **kwargs: Unpack[_RequestOptions]) -> ClientResponse:
+async def request(
+    method: str, url: str, _internal: bool = False, **kwargs: Unpack[_RequestOptions]
+) -> ClientResponse:  # pragma: no cover
     client = get_global_aiohttp_client()
     num_tries = 1
     while True:
@@ -127,18 +132,27 @@ async def request(method: str, url: str, **kwargs: Unpack[_RequestOptions]) -> ClientResponse:
         except ServerDisconnectedError:
             await asyncio.sleep(0.5)
         except Exception as e:
-            print(
-                f"""Hit an exception while making a request (try {num_tries}): {type(e)}: {e}
+            # Don't count this retry for internal requests since we know we are ok; if we are not, the head server will shut everything down anyway.
+            if not _internal:
+                print(
+                    f"""Hit an exception while making a request (try {num_tries}): {type(e)}: {e}
 Sleeping 0.5s and retrying...
 """
-            )
-            if num_tries >= MAX_NUM_TRIES:
-                raise e
+                )
+                if num_tries >= MAX_NUM_TRIES:
+                    raise e
+
+                num_tries += 1

-            num_tries += 1
             await asyncio.sleep(0.5)


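+# Log the failing response before raising so the error is visible in the server logs.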
+def raise_for_status(response: ClientResponse) -> None:  # pragma: no cover
+    if not response.ok:
+        print(response.content)
+        response.raise_for_status()
+
+
 DEFAULT_HEAD_SERVER_PORT = 11000

 ServerStatus = Union[Literal["success"], Literal["connection_error"], Literal["timeout"], Literal["unknown_error"]]
@@ -193,7 +207,7 @@ async def request(
         if isinstance(json_obj, BaseModel):
             kwargs["json"] = json_obj.model_dump(exclude_unset=True)

-        return await request(method=method, url=f"{base_url}{url_path}", **kwargs)
+        return await request(method=method, url=f"{base_url}{url_path}", _internal=True, **kwargs)

     async def get(
         self,
@@ -324,6 +338,24 @@ async def add_session_id(request: Request, call_next): # pragma: no cover
         session_middleware_key = self.get_session_middleware_key()
         app.add_middleware(SessionMiddleware, secret_key=session_middleware_key, session_cookie=session_middleware_key)

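+    # Catch-all middleware: log any unhandled exception and turn it into a 500 JSON response instead of letting the request fail silently.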
+    def setup_exception_middleware(self, app: FastAPI) -> None:  # pragma: no cover
+        @app.middleware("http")
+        async def exception_handling_middleware(request: Request, call_next):
+            try:
+                return await call_next(request)
+            except Exception as e:
+                print_exc()
+                print(
+                    f"🚨 Caught an exception (printed above) in {self.config.name} ({self.__class__.__name__}). If you expect this to be fed back into the model, the exception repr, i.e. `repr(e)`, is what gets returned. Please make sure this exception is caught in your server and returned to the model as appropriate. See https://fastapi.tiangolo.com/tutorial/handling-errors/#use-httpexception"
+                )
+                return JSONResponse(content=repr(e), status_code=500)
+            except:
+                print_exc()
+                print(
+                    f"🚨 Caught an unknown exception (printed above) in {self.config.name} ({self.__class__.__name__}). If you expect this to be fed back into the model, nothing meaningful is returned. Please make sure this exception is caught in your server and returned to the model as appropriate. See https://fastapi.tiangolo.com/tutorial/handling-errors/#use-httpexception"
+                )
+                return JSONResponse(content="An unknown error occurred", status_code=500)
+
     def setup_profiling(self, app: FastAPI, profiling_config: ProfilingMiddlewareConfig) -> None:  # pragma: no cover
         base_profile_dir = Path(PARENT_DIR) / profiling_config.profiling_results_dirpath
         server_profile_path = (base_profile_dir / self.get_session_middleware_key()).with_suffix(".log")
@@ -332,18 +364,7 @@ def setup_profiling(self, app: FastAPI, profiling_config: ProfilingMiddlewareConfig) -> None:

         main_app_lifespan = app.router.lifespan_context

-        @asynccontextmanager
-        async def lifespan_wrapper(app):
-            yappi.set_clock_type("WALL")
-            yappi.start()
-            print(f"🔍 Enabled profiling for {self.config.name}")
-
-            async with main_app_lifespan(app) as maybe_state:
-                yield maybe_state
-
-            print(f"🛑 Stopping profiler for {self.config.name}. Check {server_profile_path} for the metrics!")
-            yappi.stop()
-
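+        # Render yappi's stats to a string, keeping the header plus only the rows that mention this server's entrypoint.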
+        def _dump_yappi_stats() -> str:
             buffer = StringIO()
             yappi.get_func_stats().print_all(
                 out=buffer,
@@ -357,17 +378,56 @@ async def lifespan_wrapper(app):
             )

             buffer.seek(0)
-            with open(server_profile_path, "w") as f:
-                past_header = False
-                for line in buffer:
-                    if not past_header or self.config.entrypoint in line:
-                        f.write(line)
+            res = ""
+            past_header = False
+            for line in buffer:
+                if not past_header or self.config.entrypoint in line:
+                    res += line
+
+                if line.startswith("name"):
+                    past_header = True

-                    if line.startswith("name"):
-                        past_header = True
+            return res
+
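+        # Profile the whole app lifespan: start yappi at startup, stop it at shutdown, and persist the filtered stats to the per-server log file.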
+        @asynccontextmanager
+        async def lifespan_wrapper(app):
+            yappi.set_clock_type("CPU")
+            yappi.start()
+            print(f"🔍 Enabled profiling for {self.config.name}")
+
+            async with main_app_lifespan(app) as maybe_state:
+                yield maybe_state
+
+            print(f"🛑 Stopping profiler for {self.config.name}. Check {server_profile_path} for the metrics!")
+            yappi.stop()
+
+            with open(server_profile_path, "w") as f:
+                f.write(_dump_yappi_stats())

         app.router.lifespan_context = lifespan_wrapper

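+        # Also expose the current profile over HTTP so it can be inspected while the server is running.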
+        @app.get("/stats")
+        def stats():
+            return Response(_dump_yappi_stats())
+
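+    # Raise the soft RLIMIT_NOFILE so the server does not run out of file descriptors under many concurrent connections.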
+    def set_ulimit(self, target_soft_limit: int = 65535):  # pragma: no cover
+        # From https://github.com/vllm-project/vllm/blob/fed8a9b107df3e27d57728c6911c7d308b871477/vllm/utils/__init__.py#L2790
+        resource_type = resource.RLIMIT_NOFILE
+        current_soft, current_hard = resource.getrlimit(resource_type)
+
+        if current_soft < target_soft_limit:
+            try:
+                resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+            except ValueError as e:
+                print(
+                    f"Found ulimit of {current_soft} and failed to automatically increase "
+                    f"with error {e}. This can cause fd limit errors like "
+                    "`OSError: [Errno 24] Too many open files`. Consider "
+                    "increasing it with `ulimit -n`."
+                )
+
     @classmethod
     def run_webserver(cls) -> None:  # pragma: no cover
         global_config_dict = get_global_config_dict()
@@ -380,6 +440,8 @@ def run_webserver(cls) -> None: # pragma: no cover
         server = cls(config=server_config, server_client=server_client)

         app = server.setup_webserver()
+        server.set_ulimit()
+        server.setup_exception_middleware(app)

         profiling_config = ProfilingMiddlewareConfig.model_validate(global_config_dict)
         if profiling_config.profiling_enabled: