Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions src/s2_sdk/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,10 @@ async def unary_request(
raise ReadTimeoutError("Request timed out")
finally:
if stream_id is not None:
nbytes = _take_all_unacked_flow_bytes(state)
if nbytes > 0:
try:
await conn.ack_data(stream_id, nbytes)
except Exception:
pass
try:
await conn.ack_all_data(stream_id, state)
except Exception:
pass
if not state.ended.is_set():
await conn.reset_stream(stream_id)
conn.release_stream(stream_id, state)
Expand Down Expand Up @@ -263,12 +261,10 @@ async def _ack_stream_data(nbytes: int) -> None:
pass
# Ack remaining flow bytes to keep connection window healthy
if stream_id is not None:
nbytes = _take_all_unacked_flow_bytes(state)
if nbytes > 0:
try:
await conn.ack_data(stream_id, nbytes)
except Exception:
pass
try:
await conn.ack_all_data(stream_id, state)
except Exception:
pass
if not state.ended.is_set():
await conn.reset_stream(stream_id)
conn.release_stream(stream_id, state)
Expand Down Expand Up @@ -768,6 +764,16 @@ async def ack_data(self, stream_id: int, nbytes: int) -> None:
self._h2.acknowledge_received_data(nbytes, stream_id)
await self._flush_h2_data()

async def ack_all_data(self, stream_id: int, state: _StreamState) -> None:
"""Acknowledge all received data for stream cleanup."""
assert self._h2 is not None
async with self._write_lock:
nbytes = state.unacked_flow_bytes
state.unacked_flow_bytes = 0
if nbytes > 0:
self._h2.acknowledge_received_data(nbytes, stream_id)
await self._flush_h2_data()

async def reset_stream(self, stream_id: int) -> None:
"""Send RST_STREAM to tell the peer to stop sending."""
assert self._h2 is not None
Expand Down Expand Up @@ -989,12 +995,6 @@ def _queue_item_parts(item: tuple[bytes, int] | bytes) -> tuple[bytes, int]:
return item, len(item)


def _take_all_unacked_flow_bytes(state: _StreamState) -> int:
nbytes = state.unacked_flow_bytes
state.unacked_flow_bytes = 0
return nbytes


def _parse_retry_after_ms(raw: str | None) -> float | None:
if raw is None:
return None
Expand Down
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ async def test_unary_request_timeout_acks_and_resets_stream():
conn = AsyncMock()
conn.send_headers = AsyncMock(return_value=1)
conn.release_stream = MagicMock()
conn.ack_data = AsyncMock()
conn.ack_all_data = AsyncMock()
conn.reset_stream = AsyncMock()

pc = MagicMock()
Expand All @@ -226,7 +226,7 @@ async def test_unary_request_timeout_acks_and_resets_stream():
with pytest.raises(ReadTimeoutError, match="Request timed out"):
await client.unary_request("GET", "/v1/test")

conn.ack_data.assert_awaited_once_with(1, 11)
conn.ack_all_data.assert_awaited_once_with(1, state)
conn.reset_stream.assert_awaited_once_with(1)
conn.release_stream.assert_called_once_with(1, state)
pc.touch_idle.assert_called_once()
Expand Down
Loading