Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 86 additions & 88 deletions src/groq/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,50 +55,49 @@ def __stream__(self) -> Iterator[_T]:
process_data = self._client._process_response_data
iterator = self._iter_events()

try:
for sse in iterator:
if sse.data.startswith("[DONE]"):
break

if sse.event is None:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data=data, cast_to=cast_to, response=response)

else:
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
finally:
# Ensure the response is closed even if the consumer doesn't read all data
response.close()
for sse in iterator:
if sse.data.startswith("[DONE]"):
break

if sse.event is None:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data=data, cast_to=cast_to, response=response)

else:
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
# The stream needs to be fully consumed to close the response
for _sse in iterator:
...

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Stream not consumed when exception occurs during streaming

The removal of the try-finally block creates a bug where the stream-consuming code at the end of __stream__ (lines 98-100 for sync, 199-201 for async) is never reached if an APIError is raised during iteration. When an error SSE is received and APIError is raised, the exception propagates immediately without consuming the remaining stream, which could cause the same hanging issue this PR attempts to fix. The stream consumption logic needs to be wrapped in a try-finally to ensure it executes even when exceptions occur.

Additional Locations (1)

Fix in Cursor Fix in Web


def __enter__(self) -> Self:
return self
Expand Down Expand Up @@ -157,50 +156,49 @@ async def __stream__(self) -> AsyncIterator[_T]:
process_data = self._client._process_response_data
iterator = self._iter_events()

try:
async for sse in iterator:
if sse.data.startswith("[DONE]"):
break

if sse.event is None:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data=data, cast_to=cast_to, response=response)

else:
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
finally:
# Ensure the response is closed even if the consumer doesn't read all data
await response.aclose()
async for sse in iterator:
if sse.data.startswith("[DONE]"):
break

if sse.event is None:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data=data, cast_to=cast_to, response=response)

else:
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
# The stream needs to be fully consumed to close the response
async for _sse in iterator:
...

async def __aenter__(self) -> Self:
return self
Expand Down