From 6cb4e9017623286051a2322e89607357bce18695 Mon Sep 17 00:00:00 2001 From: quettabit <27509167+quettabit@users.noreply.github.com> Date: Sat, 11 Apr 2026 00:33:50 -0600 Subject: [PATCH] initial commit --- src/s2_sdk/_client.py | 36 ++++++++++++++++++------------------ tests/test_client.py | 4 ++-- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/s2_sdk/_client.py b/src/s2_sdk/_client.py index 81e7cb3..b043446 100644 --- a/src/s2_sdk/_client.py +++ b/src/s2_sdk/_client.py @@ -171,12 +171,10 @@ async def unary_request( raise ReadTimeoutError("Request timed out") finally: if stream_id is not None: - nbytes = _take_all_unacked_flow_bytes(state) - if nbytes > 0: - try: - await conn.ack_data(stream_id, nbytes) - except Exception: - pass + try: + await conn.ack_all_data(stream_id, state) + except Exception: + pass if not state.ended.is_set(): await conn.reset_stream(stream_id) conn.release_stream(stream_id, state) @@ -263,12 +261,10 @@ async def _ack_stream_data(nbytes: int) -> None: pass # Ack remaining flow bytes to keep connection window healthy if stream_id is not None: - nbytes = _take_all_unacked_flow_bytes(state) - if nbytes > 0: - try: - await conn.ack_data(stream_id, nbytes) - except Exception: - pass + try: + await conn.ack_all_data(stream_id, state) + except Exception: + pass if not state.ended.is_set(): await conn.reset_stream(stream_id) conn.release_stream(stream_id, state) @@ -768,6 +764,16 @@ async def ack_data(self, stream_id: int, nbytes: int) -> None: self._h2.acknowledge_received_data(nbytes, stream_id) await self._flush_h2_data() + async def ack_all_data(self, stream_id: int, state: _StreamState) -> None: + """Acknowledge all received data for stream cleanup.""" + assert self._h2 is not None + async with self._write_lock: + nbytes = state.unacked_flow_bytes + state.unacked_flow_bytes = 0 + if nbytes > 0: + self._h2.acknowledge_received_data(nbytes, stream_id) + await self._flush_h2_data() + async def reset_stream(self, stream_id: int) -> None: """Send RST_STREAM to tell the peer to stop sending.""" assert self._h2 is not None @@ -989,12 +995,6 @@ def _queue_item_parts(item: tuple[bytes, int] | bytes) -> tuple[bytes, int]: return item, len(item) -def _take_all_unacked_flow_bytes(state: _StreamState) -> int: - nbytes = state.unacked_flow_bytes - state.unacked_flow_bytes = 0 - return nbytes - - def _parse_retry_after_ms(raw: str | None) -> float | None: if raw is None: return None diff --git a/tests/test_client.py b/tests/test_client.py index 04a3794..156fd57 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -209,7 +209,7 @@ async def test_unary_request_timeout_acks_and_resets_stream(): conn = AsyncMock() conn.send_headers = AsyncMock(return_value=1) conn.release_stream = MagicMock() - conn.ack_data = AsyncMock() + conn.ack_all_data = AsyncMock() conn.reset_stream = AsyncMock() pc = MagicMock() @@ -226,7 +226,7 @@ async def test_unary_request_timeout_acks_and_resets_stream(): with pytest.raises(ReadTimeoutError, match="Request timed out"): await client.unary_request("GET", "/v1/test") - conn.ack_data.assert_awaited_once_with(1, 11) + conn.ack_all_data.assert_awaited_once_with(1, state) conn.reset_stream.assert_awaited_once_with(1) conn.release_stream.assert_called_once_with(1, state) pc.touch_idle.assert_called_once()