29 changes: 21 additions & 8 deletions lmdeploy/serve/openai/api_server.py
@@ -486,10 +486,12 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
     completion_tokens=res.generate_token_len,
     total_tokens=total_tokens,
 )
+delta_token_ids = res.token_ids if res.token_ids is not None else []
+delta_message = DeltaMessage(role='assistant', content=res.response)
+if request.return_token_ids:
+    delta_message.gen_tokens = delta_token_ids
 if has_parser:
     current_text = current_text + res.response
-    delta_token_ids = res.token_ids if res.token_ids is not None else []
     current_token_ids = current_token_ids + delta_token_ids
 if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
     if res.finish_reason == 'stop' and streaming_tools is True:
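Note: the hunk above moves delta_token_ids out of the has_parser branch so the IDs are computed even without a tool parser, and attaches them to the streamed delta when return_token_ids is set. A minimal client sketch for the streaming path follows; the host, port, and model name are placeholders, not taken from this PR.

# Minimal streaming client sketch: request token IDs alongside the text.
# Assumes a running lmdeploy OpenAI-compatible server; host, port and
# model name below are illustrative placeholders.
import json
import requests

resp = requests.post(
    'http://localhost:23333/v1/chat/completions',
    json={
        'model': 'internlm2',
        'messages': [{'role': 'user', 'content': 'Hello'}],
        'stream': True,
        'return_token_ids': True,  # new request field added by this PR
    },
    stream=True,
)

token_ids = []
for line in resp.iter_lines():
    if not line or not line.startswith(b'data: '):
        continue
    payload = line[len(b'data: '):]
    if payload == b'[DONE]':
        break
    delta = json.loads(payload)['choices'][0]['delta']
    print(delta.get('content') or '', end='')
    # 'gen_tokens' carries this step's generated token IDs when requested.
    token_ids.extend(delta.get('gen_tokens') or [])
print('\ntoken ids:', token_ids)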
@@ -586,9 +588,15 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:

 assert final_res is not None
 choices = []
+chat_message = ChatMessage(role='assistant',
+                           content=text,
+                           tool_calls=tool_calls,
+                           reasoning_content=reasoning_content)
+if request.return_token_ids:
+    chat_message.gen_tokens = final_token_ids
 choice_data = ChatCompletionResponseChoice(
     index=0,
-    message=ChatMessage(role='assistant', content=text, tool_calls=tool_calls, reasoning_content=reasoning_content),
+    message=chat_message,
     logprobs=logprobs,
     finish_reason=final_res.finish_reason,
 )
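In the non-streaming chat path the IDs arrive once, on the final assistant message. A short sketch of reading them, same placeholder host and model as above:

# Non-streaming: gen_tokens rides on the final assistant message.
import requests

data = requests.post(
    'http://localhost:23333/v1/chat/completions',
    json={
        'model': 'internlm2',
        'messages': [{'role': 'user', 'content': 'Hello'}],
        'return_token_ids': True,
    },
).json()

message = data['choices'][0]['message']
print(message['content'])
print('token ids:', message.get('gen_tokens'))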
@@ -732,9 +740,11 @@ def create_stream_response_json(index: int,
                                 text: str,
                                 finish_reason: Optional[str] = None,
                                 logprobs: Optional[LogProbs] = None,
+                                gen_tokens: Optional[List[int]] = None,
                                 usage: Optional[UsageInfo] = None) -> str:
     choice_data = CompletionResponseStreamChoice(index=index,
                                                  text=text,
+                                                 gen_tokens=gen_tokens,
                                                  finish_reason=finish_reason,
                                                  logprobs=logprobs)
     response = CompletionStreamResponse(
@@ -771,8 +781,12 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
     completion_tokens=final_res.generate_token_len,
     total_tokens=total_tokens,
 )
+gen_tokens = None
+if request.return_token_ids:
+    gen_tokens = res.token_ids or []
 response_json = create_stream_response_json(index=0,
                                             text=res.response,
+                                            gen_tokens=gen_tokens,
                                             finish_reason=res.finish_reason,
                                             logprobs=logprobs,
                                             usage=usage)
@@ -822,12 +836,11 @@ async def _inner_call(i, generator):
         spaces_between_special_tokens=gen_config.spaces_between_special_tokens)

 assert final_res is not None
-choice_data = CompletionResponseChoice(
-    index=i,
-    text=text,
-    finish_reason=final_res.finish_reason,
-    logprobs=logprobs,
-)
+choice_data = CompletionResponseChoice(index=i,
+                                       text=text,
+                                       finish_reason=final_res.finish_reason,
+                                       logprobs=logprobs,
+                                       gen_tokens=final_token_ids if request.return_token_ids else None)
 choices[i] = choice_data

 if with_cache:
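The /v1/completions endpoint mirrors the chat behavior: in streaming mode each chunk's choice carries that step's IDs via create_stream_response_json, and in the non-streaming path gen_tokens sits directly on the choice rather than on a message. A sketch of the non-streaming case, placeholders as above:

# /v1/completions: gen_tokens is a field of the choice itself.
import requests

data = requests.post(
    'http://localhost:23333/v1/completions',
    json={
        'model': 'internlm2',
        'prompt': 'Hello',
        'return_token_ids': True,
    },
).json()

choice = data['choices'][0]
print(choice['text'])
print('token ids:', choice.get('gen_tokens'))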
6 changes: 6 additions & 0 deletions lmdeploy/serve/openai/protocol.py
@@ -149,6 +149,7 @@ class ChatCompletionRequest(BaseModel):
     min_new_tokens: Optional[int] = Field(default=None, examples=[None])
     min_p: float = 0.0
     enable_thinking: Optional[bool] = None
+    return_token_ids: Optional[bool] = False


 class FunctionCall(BaseModel):
@@ -179,6 +180,7 @@ class ChatMessage(BaseModel):
"""Chat messages."""
role: str
content: Optional[str] = None
gen_tokens: Optional[List[int]] = None
reasoning_content: Optional[str] = Field(default=None, examples=[None])
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])

@@ -243,6 +245,7 @@ class DeltaMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
+    gen_tokens: Optional[List[int]] = None
     tool_calls: List[DeltaToolCall] = Field(default_factory=list)


@@ -300,13 +303,15 @@ class CompletionRequest(BaseModel):
     top_k: Optional[int] = 40  # for opencompass
     seed: Optional[int] = None
     min_p: float = 0.0
+    return_token_ids: Optional[bool] = False


 class CompletionResponseChoice(BaseModel):
     """Completion response choices."""
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
+    gen_tokens: Optional[List[int]] = None
     finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error']] = None


@@ -325,6 +330,7 @@ class CompletionResponseStreamChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
+    gen_tokens: Optional[List[int]] = None
     finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error']] = None
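All the new gen_tokens fields default to None, so clients that never set return_token_ids see no change wherever None fields are excluded from serialization. A standalone sketch of that default behavior; it re-declares a minimal model rather than importing lmdeploy.serve.openai.protocol, and assumes pydantic v2:

# Standalone sketch mirroring the new protocol field defaults.
from typing import List, Optional

from pydantic import BaseModel


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    gen_tokens: Optional[List[int]] = None


plain = DeltaMessage(role='assistant', content='hi')
with_ids = DeltaMessage(role='assistant', content='hi', gen_tokens=[1, 2])

# gen_tokens stays None unless the server populates it ...
assert plain.gen_tokens is None
# ... and drops out entirely when None fields are excluded.
print(plain.model_dump(exclude_none=True))     # {'role': 'assistant', 'content': 'hi'}
print(with_ids.model_dump(exclude_none=True))  # also includes 'gen_tokens': [1, 2]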

