diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 0a8841f946..07fcc41396 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -316,6 +316,20 @@ async def flush_cache():
     )
 
 
+@app.post("/pause_generation")
+async def pause_generation():
+    """Pause generation on the HTTP server manager and report success."""
+    await g_objs.httpserver_manager.pause_generation()
+    return Response(content="Generation paused successfully.", status_code=200)
+
+
+@app.post("/continue_generation")
+async def continue_generation():
+    """Resume generation on the HTTP server manager and report success."""
+    await g_objs.httpserver_manager.continue_generation()
+    return Response(content="Generation continued successfully.", status_code=200)
+
+
 @app.websocket("/pd_register")
 async def register_and_keep_alive(websocket: WebSocket):
     await websocket.accept()
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 0dab8fc8cc..765b44eea0 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -123,6 +123,10 @@ def __init__(
         self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark")
         self.latest_success_infer_time_mark.set_value(int(time.time()))
 
+        # Pause gate: generate() waits on this condition while is_pause is True.
+        self.is_pause = False
+        self.is_pause_cond = asyncio.Condition()
+
         # 交互式请求 event
         self.flush_cache_event: Optional[asyncio.Event] = None
         return
@@ -302,6 +306,11 @@ async def generate(
 
             # 记录请求到达的相关信息
             await self._log_req_header(request_headers, group_request_id)
+
+            # Hold new requests here until generation is no longer paused.
+            async with self.is_pause_cond:
+                await self.is_pause_cond.wait_for(lambda: not self.is_pause)
+
             # encode
             prompt_ids = await self._encode(prompt, multimodal_params, sampling_params)
 
@@ -832,6 +841,31 @@ async def flush_cache(self):
         self.flush_cache_event.clear()
         return ret
 
+    async def pause_generation(self):
+        """Pause generation and abort requests until none remain in flight.
+
+        Only the master node needs to pause: requests are forwarded from the
+        master node to the slave nodes, so the slaves pause along with it.
+
+        The pause condition's lock is held for the whole drain, so callers
+        entering ``generate`` and ``continue_generation`` wait until every
+        tracked request is gone.
+        """
+        async with self.is_pause_cond:
+            self.is_pause = True
+            while True:
+                await self.abort_request(AbortReq(request_id=None, abort_all=True))
+                # Stop once no request is tracked any more; otherwise retry after a pause.
+                if not self.req_id_to_out_inf:
+                    break
+                await asyncio.sleep(1.0)
+
+    async def continue_generation(self):
+        """Clear the pause flag and wake every request waiting on the pause condition."""
+        async with self.is_pause_cond:
+            self.is_pause = False
+            self.is_pause_cond.notify_all()
+
 
 class ReqStatus:
     def __init__(