Skip to content

Commit d771bb7

Browse files
adriangbmikkelduif
andauthored
Don't poll for disconnects in BaseHTTPMiddleware via StreamingResponse (#2620)
* Don't poll for disconnects in BaseHTTPMiddleware via StreamingResponse Fixes #2516 * add test * fmt * Update tests/middleware/test_base.py Co-authored-by: Mikkel Duif <mikkel@duifs.dk> * add test for line now missing coverage * more coverage, fix test * add comment * fmt * tweak test * fix * fix coverage * relint --------- Co-authored-by: Mikkel Duif <mikkel@duifs.dk>
1 parent c78c9aa commit d771bb7

File tree

2 files changed

+168
-36
lines changed

2 files changed

+168
-36
lines changed

starlette/middleware/base.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
from anyio.abc import ObjectReceiveStream, ObjectSendStream
77

88
from starlette._utils import collapse_excgroups
9-
from starlette.background import BackgroundTask
109
from starlette.requests import ClientDisconnect, Request
11-
from starlette.responses import ContentStream, Response, StreamingResponse
10+
from starlette.responses import AsyncContentStream, Response
1211
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1312

1413
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
@@ -56,6 +55,7 @@ async def wrapped_receive(self) -> Message:
5655
# at this point a disconnect is all that we should be receiving
5756
# if we get something else, things went wrong somewhere
5857
raise RuntimeError(f"Unexpected message received: {msg['type']}")
58+
self._wrapped_rcv_disconnected = True
5959
return msg
6060

6161
# wrapped_rcv state 3: not yet consumed
@@ -198,20 +198,33 @@ async def dispatch(
198198
raise NotImplementedError() # pragma: no cover
199199

200200

201-
class _StreamingResponse(StreamingResponse):
201+
class _StreamingResponse(Response):
202202
def __init__(
203203
self,
204-
content: ContentStream,
204+
content: AsyncContentStream,
205205
status_code: int = 200,
206206
headers: typing.Mapping[str, str] | None = None,
207207
media_type: str | None = None,
208-
background: BackgroundTask | None = None,
209208
info: typing.Mapping[str, typing.Any] | None = None,
210209
) -> None:
211-
self._info = info
212-
super().__init__(content, status_code, headers, media_type, background)
210+
self.info = info
211+
self.body_iterator = content
212+
self.status_code = status_code
213+
self.media_type = media_type
214+
self.init_headers(headers)
213215

214-
async def stream_response(self, send: Send) -> None:
215-
if self._info:
216-
await send({"type": "http.response.debug", "info": self._info})
217-
return await super().stream_response(send)
216+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
217+
if self.info is not None:
218+
await send({"type": "http.response.debug", "info": self.info})
219+
await send(
220+
{
221+
"type": "http.response.start",
222+
"status": self.status_code,
223+
"headers": self.raw_headers,
224+
}
225+
)
226+
227+
async for chunk in self.body_iterator:
228+
await send({"type": "http.response.body", "body": chunk, "more_body": True})
229+
230+
await send({"type": "http.response.body", "body": b"", "more_body": False})

tests/middleware/test_base.py

Lines changed: 144 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import (
66
Any,
77
AsyncGenerator,
8+
AsyncIterator,
89
Generator,
910
)
1011

@@ -16,7 +17,7 @@
1617
from starlette.background import BackgroundTask
1718
from starlette.middleware import Middleware, _MiddlewareClass
1819
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
19-
from starlette.requests import Request
20+
from starlette.requests import ClientDisconnect, Request
2021
from starlette.responses import PlainTextResponse, Response, StreamingResponse
2122
from starlette.routing import Route, WebSocketRoute
2223
from starlette.testclient import TestClient
@@ -260,7 +261,6 @@ async def homepage(request: Request) -> PlainTextResponse:
260261
@pytest.mark.anyio
261262
async def test_run_background_tasks_even_if_client_disconnects() -> None:
262263
# test for https://github.com/encode/starlette/issues/1438
263-
request_body_sent = False
264264
response_complete = anyio.Event()
265265
background_task_run = anyio.Event()
266266

@@ -293,13 +293,7 @@ async def passthrough(
293293
}
294294

295295
async def receive() -> Message:
296-
nonlocal request_body_sent
297-
if not request_body_sent:
298-
request_body_sent = True
299-
return {"type": "http.request", "body": b"", "more_body": False}
300-
# We simulate a client that disconnects immediately after receiving the response
301-
await response_complete.wait()
302-
return {"type": "http.disconnect"}
296+
raise NotImplementedError("Should not be called!") # pragma: no cover
303297

304298
async def send(message: Message) -> None:
305299
if message["type"] == "http.response.body":
@@ -313,7 +307,6 @@ async def send(message: Message) -> None:
313307

314308
@pytest.mark.anyio
315309
async def test_do_not_block_on_background_tasks() -> None:
316-
request_body_sent = False
317310
response_complete = anyio.Event()
318311
events: list[str | Message] = []
319312

@@ -345,12 +338,7 @@ async def passthrough(
345338
}
346339

347340
async def receive() -> Message:
348-
nonlocal request_body_sent
349-
if not request_body_sent:
350-
request_body_sent = True
351-
return {"type": "http.request", "body": b"", "more_body": False}
352-
await response_complete.wait()
353-
return {"type": "http.disconnect"}
341+
raise NotImplementedError("Should not be called!") # pragma: no cover
354342

355343
async def send(message: Message) -> None:
356344
if message["type"] == "http.response.body":
@@ -379,7 +367,6 @@ async def send(message: Message) -> None:
379367
@pytest.mark.anyio
380368
async def test_run_context_manager_exit_even_if_client_disconnects() -> None:
381369
# test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
382-
request_body_sent = False
383370
response_complete = anyio.Event()
384371
context_manager_exited = anyio.Event()
385372

@@ -424,13 +411,7 @@ async def passthrough(
424411
}
425412

426413
async def receive() -> Message:
427-
nonlocal request_body_sent
428-
if not request_body_sent:
429-
request_body_sent = True
430-
return {"type": "http.request", "body": b"", "more_body": False}
431-
# We simulate a client that disconnects immediately after receiving the response
432-
await response_complete.wait()
433-
return {"type": "http.disconnect"}
414+
raise NotImplementedError("Should not be called!") # pragma: no cover
434415

435416
async def send(message: Message) -> None:
436417
if message["type"] == "http.response.body":
@@ -778,7 +759,9 @@ async def rcv() -> AsyncGenerator[Message, None]:
778759
yield {"type": "http.request", "body": b"1", "more_body": True}
779760
yield {"type": "http.request", "body": b"2", "more_body": True}
780761
yield {"type": "http.request", "body": b"3"}
781-
await anyio.sleep(float("inf"))
762+
raise AssertionError( # pragma: no cover
763+
"Should not be called, no need to poll for disconnect"
764+
)
782765

783766
sent: list[Message] = []
784767

@@ -1033,3 +1016,139 @@ async def endpoint(request: Request) -> Response:
10331016
resp.raise_for_status()
10341017

10351018
assert bodies == [b"Hello, World!-foo"]
1019+
1020+
1021+
@pytest.mark.anyio
1022+
async def test_multiple_middlewares_stacked_client_disconnected() -> None:
1023+
class MyMiddleware(BaseHTTPMiddleware):
1024+
def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None:
1025+
self.version = version
1026+
self.events = events
1027+
super().__init__(app)
1028+
1029+
async def dispatch(
1030+
self, request: Request, call_next: RequestResponseEndpoint
1031+
) -> Response:
1032+
self.events.append(f"{self.version}:STARTED")
1033+
res = await call_next(request)
1034+
self.events.append(f"{self.version}:COMPLETED")
1035+
return res
1036+
1037+
async def sleepy(request: Request) -> Response:
1038+
try:
1039+
await request.body()
1040+
except ClientDisconnect:
1041+
pass
1042+
else: # pragma: no cover
1043+
raise AssertionError("Should have raised ClientDisconnect")
1044+
return Response(b"")
1045+
1046+
events: list[str] = []
1047+
1048+
app = Starlette(
1049+
routes=[Route("/", sleepy)],
1050+
middleware=[
1051+
Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)
1052+
],
1053+
)
1054+
1055+
scope = {
1056+
"type": "http",
1057+
"version": "3",
1058+
"method": "GET",
1059+
"path": "/",
1060+
}
1061+
1062+
async def receive() -> AsyncIterator[Message]:
1063+
yield {"type": "http.disconnect"}
1064+
1065+
sent: list[Message] = []
1066+
1067+
async def send(message: Message) -> None:
1068+
sent.append(message)
1069+
1070+
await app(scope, receive().__anext__, send)
1071+
1072+
assert events == [
1073+
"1:STARTED",
1074+
"2:STARTED",
1075+
"3:STARTED",
1076+
"4:STARTED",
1077+
"5:STARTED",
1078+
"6:STARTED",
1079+
"7:STARTED",
1080+
"8:STARTED",
1081+
"9:STARTED",
1082+
"10:STARTED",
1083+
"10:COMPLETED",
1084+
"9:COMPLETED",
1085+
"8:COMPLETED",
1086+
"7:COMPLETED",
1087+
"6:COMPLETED",
1088+
"5:COMPLETED",
1089+
"4:COMPLETED",
1090+
"3:COMPLETED",
1091+
"2:COMPLETED",
1092+
"1:COMPLETED",
1093+
]
1094+
1095+
assert sent == [
1096+
{
1097+
"type": "http.response.start",
1098+
"status": 200,
1099+
"headers": [(b"content-length", b"0")],
1100+
},
1101+
{"type": "http.response.body", "body": b"", "more_body": False},
1102+
]
1103+
1104+
1105+
@pytest.mark.anyio
1106+
@pytest.mark.parametrize("send_body", [True, False])
1107+
async def test_poll_for_disconnect_repeated(send_body: bool) -> None:
1108+
async def app_poll_disconnect(scope: Scope, receive: Receive, send: Send) -> None:
1109+
for _ in range(2):
1110+
msg = await receive()
1111+
while msg["type"] == "http.request":
1112+
msg = await receive()
1113+
assert msg["type"] == "http.disconnect"
1114+
await Response(b"good!")(scope, receive, send)
1115+
1116+
class MyMiddleware(BaseHTTPMiddleware):
1117+
async def dispatch(
1118+
self, request: Request, call_next: RequestResponseEndpoint
1119+
) -> Response:
1120+
return await call_next(request)
1121+
1122+
app = MyMiddleware(app_poll_disconnect)
1123+
1124+
scope = {
1125+
"type": "http",
1126+
"version": "3",
1127+
"method": "GET",
1128+
"path": "/",
1129+
}
1130+
1131+
async def receive() -> AsyncIterator[Message]:
1132+
# the key here is that we only ever send 1 htt.disconnect message
1133+
if send_body:
1134+
yield {"type": "http.request", "body": b"hello", "more_body": True}
1135+
yield {"type": "http.request", "body": b"", "more_body": False}
1136+
yield {"type": "http.disconnect"}
1137+
raise AssertionError("Should not be called, would hang") # pragma: no cover
1138+
1139+
sent: list[Message] = []
1140+
1141+
async def send(message: Message) -> None:
1142+
sent.append(message)
1143+
1144+
await app(scope, receive().__anext__, send)
1145+
1146+
assert sent == [
1147+
{
1148+
"type": "http.response.start",
1149+
"status": 200,
1150+
"headers": [(b"content-length", b"5")],
1151+
},
1152+
{"type": "http.response.body", "body": b"good!", "more_body": True},
1153+
{"type": "http.response.body", "body": b"", "more_body": False},
1154+
]

0 commit comments

Comments
 (0)