From 306a39f2a1871210b4545d16322e1d76a23de2ea Mon Sep 17 00:00:00 2001
From: RunningLeon
Date: Thu, 3 Jul 2025 22:30:03 +0800
Subject: [PATCH] cherry pick PR-3708

---
 lmdeploy/serve/openai/api_server.py | 29 +++++++++++++++++++++--------
 lmdeploy/serve/openai/protocol.py   |  6 ++++++
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index eef4cc1b34..40bcfd0eed 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -486,10 +486,12 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
                     completion_tokens=res.generate_token_len,
                     total_tokens=total_tokens,
                 )
+            delta_token_ids = res.token_ids if res.token_ids is not None else []
             delta_message = DeltaMessage(role='assistant', content=res.response)
+            if request.return_token_ids:
+                delta_message.gen_tokens = delta_token_ids
             if has_parser:
                 current_text = current_text + res.response
-                delta_token_ids = res.token_ids if res.token_ids is not None else []
                 current_token_ids = current_token_ids + delta_token_ids
                 if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
                     if res.finish_reason == 'stop' and streaming_tools is True:
@@ -586,9 +588,15 @@
 
     assert final_res is not None
     choices = []
+    chat_message = ChatMessage(role='assistant',
+                               content=text,
+                               tool_calls=tool_calls,
+                               reasoning_content=reasoning_content)
+    if request.return_token_ids:
+        chat_message.gen_tokens = final_token_ids
     choice_data = ChatCompletionResponseChoice(
         index=0,
-        message=ChatMessage(role='assistant', content=text, tool_calls=tool_calls, reasoning_content=reasoning_content),
+        message=chat_message,
         logprobs=logprobs,
         finish_reason=final_res.finish_reason,
     )
@@ -732,9 +740,11 @@ def create_stream_response_json(index: int,
                                     text: str,
                                     finish_reason: Optional[str] = None,
                                     logprobs: Optional[LogProbs] = None,
+                                    gen_tokens: Optional[List[int]] = None,
                                     usage: Optional[UsageInfo] = None) -> str:
         choice_data = CompletionResponseStreamChoice(index=index,
                                                      text=text,
+                                                     gen_tokens=gen_tokens,
                                                      finish_reason=finish_reason,
                                                      logprobs=logprobs)
         response = CompletionStreamResponse(
@@ -771,8 +781,12 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
                     completion_tokens=final_res.generate_token_len,
                     total_tokens=total_tokens,
                 )
+            gen_tokens = None
+            if request.return_token_ids:
+                gen_tokens = res.token_ids or []
             response_json = create_stream_response_json(index=0,
                                                         text=res.response,
+                                                        gen_tokens=gen_tokens,
                                                         finish_reason=res.finish_reason,
                                                         logprobs=logprobs,
                                                         usage=usage)
@@ -822,12 +836,11 @@ async def _inner_call(i, generator):
             spaces_between_special_tokens=gen_config.spaces_between_special_tokens)
 
         assert final_res is not None
-        choice_data = CompletionResponseChoice(
-            index=i,
-            text=text,
-            finish_reason=final_res.finish_reason,
-            logprobs=logprobs,
-        )
+        choice_data = CompletionResponseChoice(index=i,
+                                               text=text,
+                                               finish_reason=final_res.finish_reason,
+                                               logprobs=logprobs,
+                                               gen_tokens=final_token_ids if request.return_token_ids else None)
         choices[i] = choice_data
 
         if with_cache:
diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py
index 9eef46f694..d8e78a029c 100644
--- a/lmdeploy/serve/openai/protocol.py
+++ b/lmdeploy/serve/openai/protocol.py
@@ -149,6 +149,7 @@ class ChatCompletionRequest(BaseModel):
     min_new_tokens: Optional[int] = Field(default=None, examples=[None])
     min_p: float = 0.0
     enable_thinking: Optional[bool] = None
+    return_token_ids: Optional[bool] = False
 
 
 class FunctionCall(BaseModel):
@@ -179,6 +180,7 @@ class ChatMessage(BaseModel):
     """Chat messages."""
     role: str
     content: Optional[str] = None
+    gen_tokens: Optional[List[int]] = None
     reasoning_content: Optional[str] = Field(default=None, examples=[None])
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
 
@@ -243,6 +245,7 @@ class DeltaMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
+    gen_tokens: Optional[List[int]] = None
     tool_calls: List[DeltaToolCall] = Field(default_factory=list)
 
 
@@ -300,6 +303,7 @@ class CompletionRequest(BaseModel):
     top_k: Optional[int] = 40  # for opencompass
     seed: Optional[int] = None
     min_p: float = 0.0
+    return_token_ids: Optional[bool] = False
 
 
 class CompletionResponseChoice(BaseModel):
@@ -307,6 +311,7 @@ class CompletionResponseChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
+    gen_tokens: Optional[List[int]] = None
     finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error']] = None
 
 
@@ -325,6 +330,7 @@ class CompletionResponseStreamChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
+    gen_tokens: Optional[List[int]] = None
     finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error']] = None
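
Usage note (not part of the patch): a minimal sketch of how a client might exercise
the new return_token_ids flag against an lmdeploy OpenAI-compatible server. The host,
port, model name, and prompt below are placeholder assumptions; only the request flag
and the gen_tokens response field come from the protocol changes above.

    # Sketch: request generated token ids alongside the generated text.
    # Assumes a server started with `lmdeploy serve api_server ...` and
    # reachable at localhost:23333; the model name is a placeholder.
    import requests

    resp = requests.post(
        'http://localhost:23333/v1/chat/completions',
        json={
            'model': 'internlm2',  # placeholder model name
            'messages': [{'role': 'user', 'content': 'Hello'}],
            'return_token_ids': True,  # flag added by this patch
        },
    )
    message = resp.json()['choices'][0]['message']
    print(message['content'])     # generated text, as before
    print(message['gen_tokens'])  # generated token ids (new field)

In streaming mode the same flag should populate gen_tokens on each DeltaMessage chunk,
per the delta_message changes in api_server.py above.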