Commit e1bb2fd
[Bugfix] Support logprobs when using guided_json and other constrained decoding fields (vllm-project#4149)
1 parent 705578a commit e1bb2fd

File tree

2 files changed: +33 −1 lines

tests/entrypoints/test_openai_server.py

Lines changed: 30 additions & 0 deletions
@@ -723,6 +723,36 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
             extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
 
 
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
+                                           guided_decoding_backend: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        "The best language for type-safe systems programming is "
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=10,
+        logprobs=True,
+        top_logprobs=5,
+        extra_body=dict(guided_choice=TEST_CHOICE,
+                        guided_decoding_backend=guided_decoding_backend))
+    top_logprobs = chat_completion.choices[0].logprobs.top_logprobs
+
+    # -9999.0 is the minimum logprob returned by OpenAI
+    assert all(
+        isinstance(logprob, float) and logprob >= -9999.0
+        for token_dict in top_logprobs
+        for token, logprob in token_dict.items())
+
+
 async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
     resp = await client.chat.completions.create(
         model=MODEL_NAME,
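
For context, here is a hedged, minimal usage sketch (not part of the commit) showing the same request pattern outside the pytest harness: asking a vLLM OpenAI-compatible server for logprobs together with guided_choice via extra_body. The base_url, api_key, model name, and choice list are placeholder assumptions, not values from this change.

import asyncio

import openai


async def main():
    # Placeholder endpoint and credentials for a locally running vLLM server.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    chat_completion = await client.chat.completions.create(
        model="placeholder-model",  # assumption: whatever model the server serves
        messages=[{
            "role": "user",
            "content": "The best language for type-safe systems programming is "
        }],
        max_tokens=10,
        logprobs=True,
        top_logprobs=5,
        # vLLM-specific guided decoding fields are passed through extra_body.
        extra_body=dict(guided_choice=["Rust", "Python", "Go"],
                        guided_decoding_backend="outlines"))
    # With this fix, every returned logprob is a finite, JSON-safe float.
    print(chat_completion.choices[0].logprobs.top_logprobs)


asyncio.run(main())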

vllm/entrypoints/openai/serving_engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ def _create_logprobs(
 
             if num_output_top_logprobs:
                 logprobs.top_logprobs.append({
-                    p.decoded_token: p.logprob
+                    # Convert float("-inf") to the
+                    # JSON-serializable float that OpenAI uses
+                    p.decoded_token: max(p.logprob, -9999.0)
                     for i, p in step_top_logprobs.items()
                 } if step_top_logprobs else None)
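
The clamp above addresses the root cause: constrained decoding masks out disallowed tokens, so their logprobs come back as float("-inf"), which strict JSON cannot represent. Below is a minimal sketch of the failure mode and the fix, using only the standard library; the token names and values are illustrative assumptions.

import json

# Guided decoding assigns -inf to tokens the grammar disallows.
raw_top_logprobs = {"Rust": -0.1, "Python": float("-inf")}

# Strict JSON has no encoding for -inf; with allow_nan=False the standard
# library refuses to serialize it (by default it would emit the non-standard
# token -Infinity instead).
try:
    json.dumps(raw_top_logprobs, allow_nan=False)
except ValueError as err:
    print("serialization fails:", err)

# The fix clamps to -9999.0, the minimum logprob the OpenAI API returns,
# keeping the response JSON-serializable.
safe_top_logprobs = {
    token: max(logprob, -9999.0)
    for token, logprob in raw_top_logprobs.items()
}
print(json.dumps(safe_top_logprobs, allow_nan=False))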
