Skip to content

Commit 72e1e15

Browse files
karpetrosyan and stainless-app[bot]
authored and committed
fix(structured outputs): resolve memory leak in parse methods (#2860)
1 parent d81ee8f commit 72e1e15

File tree

7 files changed

+92
-135
lines changed

7 files changed

+92
-135
lines changed

src/openai/lib/_parsing/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
validate_input_tools as validate_input_tools,
77
parse_chat_completion as parse_chat_completion,
88
get_input_tool_by_name as get_input_tool_by_name,
9-
solve_response_format_t as solve_response_format_t,
109
parse_function_tool_arguments as parse_function_tool_arguments,
1110
type_to_response_format_param as type_to_response_format_param,
1211
)

src/openai/lib/_parsing/_completions.py

Lines changed: 7 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -138,7 +138,7 @@ def parse_chat_completion(
138138

139139
choices.append(
140140
construct_type_unchecked(
141-
type_=cast(Any, ParsedChoice)[solve_response_format_t(response_format)],
141+
type_=ParsedChoice[ResponseFormatT],
142142
value={
143143
**choice.to_dict(),
144144
"message": {
@@ -153,15 +153,12 @@ def parse_chat_completion(
153153
)
154154
)
155155

156-
return cast(
157-
ParsedChatCompletion[ResponseFormatT],
158-
construct_type_unchecked(
159-
type_=cast(Any, ParsedChatCompletion)[solve_response_format_t(response_format)],
160-
value={
161-
**chat_completion.to_dict(),
162-
"choices": choices,
163-
},
164-
),
156+
return construct_type_unchecked(
157+
type_=ParsedChatCompletion[ResponseFormatT],
158+
value={
159+
**chat_completion.to_dict(),
160+
"choices": choices,
161+
},
165162
)
166163

167164

@@ -201,20 +198,6 @@ def maybe_parse_content(
201198
return None
202199

203200

204-
def solve_response_format_t(
205-
response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
206-
) -> type[ResponseFormatT]:
207-
"""Return the runtime type for the given response format.
208-
209-
If no response format is given, or if we won't auto-parse the response format
210-
then we default to `None`.
211-
"""
212-
if has_rich_response_format(response_format):
213-
return response_format
214-
215-
return cast("type[ResponseFormatT]", _default_response_format)
216-
217-
218201
def has_parseable_input(
219202
*,
220203
response_format: type | ResponseFormatParam | Omit,

src/openai/lib/_parsing/_responses.py

Lines changed: 10 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import json
4-
from typing import TYPE_CHECKING, Any, List, Iterable, cast
4+
from typing import TYPE_CHECKING, List, Iterable, cast
55
from typing_extensions import TypeVar, assert_never
66

77
import pydantic
@@ -12,7 +12,7 @@
1212
from ..._compat import PYDANTIC_V1, model_parse_json
1313
from ..._models import construct_type_unchecked
1414
from .._pydantic import is_basemodel_type, is_dataclass_like_type
15-
from ._completions import solve_response_format_t, type_to_response_format_param
15+
from ._completions import type_to_response_format_param
1616
from ...types.responses import (
1717
Response,
1818
ToolParam,
@@ -56,7 +56,6 @@ def parse_response(
5656
input_tools: Iterable[ToolParam] | Omit | None,
5757
response: Response | ParsedResponse[object],
5858
) -> ParsedResponse[TextFormatT]:
59-
solved_t = solve_response_format_t(text_format)
6059
output_list: List[ParsedResponseOutputItem[TextFormatT]] = []
6160

6261
for output in response.output:
@@ -69,7 +68,7 @@ def parse_response(
6968

7069
content_list.append(
7170
construct_type_unchecked(
72-
type_=cast(Any, ParsedResponseOutputText)[solved_t],
71+
type_=ParsedResponseOutputText[TextFormatT],
7372
value={
7473
**item.to_dict(),
7574
"parsed": parse_text(item.text, text_format=text_format),
@@ -79,7 +78,7 @@ def parse_response(
7978

8079
output_list.append(
8180
construct_type_unchecked(
82-
type_=cast(Any, ParsedResponseOutputMessage)[solved_t],
81+
type_=ParsedResponseOutputMessage[TextFormatT],
8382
value={
8483
**output.to_dict(),
8584
"content": content_list,
@@ -123,15 +122,12 @@ def parse_response(
123122
else:
124123
output_list.append(output)
125124

126-
return cast(
127-
ParsedResponse[TextFormatT],
128-
construct_type_unchecked(
129-
type_=cast(Any, ParsedResponse)[solved_t],
130-
value={
131-
**response.to_dict(),
132-
"output": output_list,
133-
},
134-
),
125+
return construct_type_unchecked(
126+
type_=ParsedResponse[TextFormatT],
127+
value={
128+
**response.to_dict(),
129+
"output": output_list,
130+
},
135131
)
136132

137133

src/openai/lib/streaming/chat/_completions.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,6 @@
3333
maybe_parse_content,
3434
parse_chat_completion,
3535
get_input_tool_by_name,
36-
solve_response_format_t,
3736
parse_function_tool_arguments,
3837
)
3938
from ...._streaming import Stream, AsyncStream
@@ -663,7 +662,7 @@ def _content_done_events(
663662
# type variable, e.g. `ContentDoneEvent[MyModelType]`
664663
cast( # pyright: ignore[reportUnnecessaryCast]
665664
"type[ContentDoneEvent[ResponseFormatT]]",
666-
cast(Any, ContentDoneEvent)[solve_response_format_t(response_format)],
665+
cast(Any, ContentDoneEvent),
667666
),
668667
type="content.done",
669668
content=choice_snapshot.message.content,

tests/lib/chat/test_completions.py

Lines changed: 36 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -50,13 +50,13 @@ def test_parse_nothing(client: OpenAI, respx_mock: MockRouter, monkeypatch: pyte
5050

5151
assert print_obj(completion, monkeypatch) == snapshot(
5252
"""\
53-
ParsedChatCompletion[NoneType](
53+
ParsedChatCompletion(
5454
choices=[
55-
ParsedChoice[NoneType](
55+
ParsedChoice(
5656
finish_reason='stop',
5757
index=0,
5858
logprobs=None,
59-
message=ParsedChatCompletionMessage[NoneType](
59+
message=ParsedChatCompletionMessage(
6060
annotations=None,
6161
audio=None,
6262
content="I'm unable to provide real-time weather updates. To get the current weather in San Francisco, I
@@ -120,13 +120,13 @@ class Location(BaseModel):
120120

121121
assert print_obj(completion, monkeypatch) == snapshot(
122122
"""\
123-
ParsedChatCompletion[Location](
123+
ParsedChatCompletion(
124124
choices=[
125-
ParsedChoice[Location](
125+
ParsedChoice(
126126
finish_reason='stop',
127127
index=0,
128128
logprobs=None,
129-
message=ParsedChatCompletionMessage[Location](
129+
message=ParsedChatCompletionMessage(
130130
annotations=None,
131131
audio=None,
132132
content='{"city":"San Francisco","temperature":65,"units":"f"}',
@@ -191,13 +191,13 @@ class Location(BaseModel):
191191

192192
assert print_obj(completion, monkeypatch) == snapshot(
193193
"""\
194-
ParsedChatCompletion[Location](
194+
ParsedChatCompletion(
195195
choices=[
196-
ParsedChoice[Location](
196+
ParsedChoice(
197197
finish_reason='stop',
198198
index=0,
199199
logprobs=None,
200-
message=ParsedChatCompletionMessage[Location](
200+
message=ParsedChatCompletionMessage(
201201
annotations=None,
202202
audio=None,
203203
content='{"city":"San Francisco","temperature":65,"units":"f"}',
@@ -266,11 +266,11 @@ class ColorDetection(BaseModel):
266266

267267
assert print_obj(completion.choices[0], monkeypatch) == snapshot(
268268
"""\
269-
ParsedChoice[ColorDetection](
269+
ParsedChoice(
270270
finish_reason='stop',
271271
index=0,
272272
logprobs=None,
273-
message=ParsedChatCompletionMessage[ColorDetection](
273+
message=ParsedChatCompletionMessage(
274274
annotations=None,
275275
audio=None,
276276
content='{"color":"red","hex_color_code":"#FF0000"}',
@@ -317,11 +317,11 @@ class Location(BaseModel):
317317
assert print_obj(completion.choices, monkeypatch) == snapshot(
318318
"""\
319319
[
320-
ParsedChoice[Location](
320+
ParsedChoice(
321321
finish_reason='stop',
322322
index=0,
323323
logprobs=None,
324-
message=ParsedChatCompletionMessage[Location](
324+
message=ParsedChatCompletionMessage(
325325
annotations=None,
326326
audio=None,
327327
content='{"city":"San Francisco","temperature":64,"units":"f"}',
@@ -332,11 +332,11 @@ class Location(BaseModel):
332332
tool_calls=None
333333
)
334334
),
335-
ParsedChoice[Location](
335+
ParsedChoice(
336336
finish_reason='stop',
337337
index=1,
338338
logprobs=None,
339-
message=ParsedChatCompletionMessage[Location](
339+
message=ParsedChatCompletionMessage(
340340
annotations=None,
341341
audio=None,
342342
content='{"city":"San Francisco","temperature":65,"units":"f"}',
@@ -347,11 +347,11 @@ class Location(BaseModel):
347347
tool_calls=None
348348
)
349349
),
350-
ParsedChoice[Location](
350+
ParsedChoice(
351351
finish_reason='stop',
352352
index=2,
353353
logprobs=None,
354-
message=ParsedChatCompletionMessage[Location](
354+
message=ParsedChatCompletionMessage(
355355
annotations=None,
356356
audio=None,
357357
content='{"city":"San Francisco","temperature":63.0,"units":"f"}',
@@ -397,13 +397,13 @@ class CalendarEvent:
397397

398398
assert print_obj(completion, monkeypatch) == snapshot(
399399
"""\
400-
ParsedChatCompletion[CalendarEvent](
400+
ParsedChatCompletion(
401401
choices=[
402-
ParsedChoice[CalendarEvent](
402+
ParsedChoice(
403403
finish_reason='stop',
404404
index=0,
405405
logprobs=None,
406-
message=ParsedChatCompletionMessage[CalendarEvent](
406+
message=ParsedChatCompletionMessage(
407407
annotations=None,
408408
audio=None,
409409
content='{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}',
@@ -462,11 +462,11 @@ def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, m
462462

463463
assert print_obj(completion.choices[0], monkeypatch) == snapshot(
464464
"""\
465-
ParsedChoice[Query](
465+
ParsedChoice(
466466
finish_reason='tool_calls',
467467
index=0,
468468
logprobs=None,
469-
message=ParsedChatCompletionMessage[Query](
469+
message=ParsedChatCompletionMessage(
470470
annotations=None,
471471
audio=None,
472472
content=None,
@@ -576,11 +576,11 @@ class Location(BaseModel):
576576
assert print_obj(completion.choices, monkeypatch) == snapshot(
577577
"""\
578578
[
579-
ParsedChoice[Location](
579+
ParsedChoice(
580580
finish_reason='stop',
581581
index=0,
582582
logprobs=None,
583-
message=ParsedChatCompletionMessage[Location](
583+
message=ParsedChatCompletionMessage(
584584
annotations=None,
585585
audio=None,
586586
content=None,
@@ -627,11 +627,11 @@ class GetWeatherArgs(BaseModel):
627627
assert print_obj(completion.choices, monkeypatch) == snapshot(
628628
"""\
629629
[
630-
ParsedChoice[NoneType](
630+
ParsedChoice(
631631
finish_reason='tool_calls',
632632
index=0,
633633
logprobs=None,
634-
message=ParsedChatCompletionMessage[NoneType](
634+
message=ParsedChatCompletionMessage(
635635
annotations=None,
636636
audio=None,
637637
content=None,
@@ -701,11 +701,11 @@ class GetStockPrice(BaseModel):
701701
assert print_obj(completion.choices, monkeypatch) == snapshot(
702702
"""\
703703
[
704-
ParsedChoice[NoneType](
704+
ParsedChoice(
705705
finish_reason='tool_calls',
706706
index=0,
707707
logprobs=None,
708-
message=ParsedChatCompletionMessage[NoneType](
708+
message=ParsedChatCompletionMessage(
709709
annotations=None,
710710
audio=None,
711711
content=None,
@@ -784,11 +784,11 @@ def test_parse_strict_tools(client: OpenAI, respx_mock: MockRouter, monkeypatch:
784784
assert print_obj(completion.choices, monkeypatch) == snapshot(
785785
"""\
786786
[
787-
ParsedChoice[NoneType](
787+
ParsedChoice(
788788
finish_reason='tool_calls',
789789
index=0,
790790
logprobs=None,
791-
message=ParsedChatCompletionMessage[NoneType](
791+
message=ParsedChatCompletionMessage(
792792
annotations=None,
793793
audio=None,
794794
content=None,
@@ -866,13 +866,13 @@ class Location(BaseModel):
866866
assert isinstance(message.parsed.city, str)
867867
assert print_obj(completion, monkeypatch) == snapshot(
868868
"""\
869-
ParsedChatCompletion[Location](
869+
ParsedChatCompletion(
870870
choices=[
871-
ParsedChoice[Location](
871+
ParsedChoice(
872872
finish_reason='stop',
873873
index=0,
874874
logprobs=None,
875-
message=ParsedChatCompletionMessage[Location](
875+
message=ParsedChatCompletionMessage(
876876
annotations=None,
877877
audio=None,
878878
content='{"city":"San Francisco","temperature":58,"units":"f"}',
@@ -943,13 +943,13 @@ class Location(BaseModel):
943943
assert isinstance(message.parsed.city, str)
944944
assert print_obj(completion, monkeypatch) == snapshot(
945945
"""\
946-
ParsedChatCompletion[Location](
946+
ParsedChatCompletion(
947947
choices=[
948-
ParsedChoice[Location](
948+
ParsedChoice(
949949
finish_reason='stop',
950950
index=0,
951951
logprobs=None,
952-
message=ParsedChatCompletionMessage[Location](
952+
message=ParsedChatCompletionMessage(
953953
annotations=None,
954954
audio=None,
955955
content='{"city":"San Francisco","temperature":65,"units":"f"}',

0 commit comments

Comments (0)