55from typing import (
66 Any ,
77 AsyncGenerator ,
8+ AsyncIterator ,
89 Generator ,
910)
1011
1617from starlette .background import BackgroundTask
1718from starlette .middleware import Middleware , _MiddlewareClass
1819from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
19- from starlette .requests import Request
20+ from starlette .requests import ClientDisconnect , Request
2021from starlette .responses import PlainTextResponse , Response , StreamingResponse
2122from starlette .routing import Route , WebSocketRoute
2223from starlette .testclient import TestClient
@@ -260,7 +261,6 @@ async def homepage(request: Request) -> PlainTextResponse:
260261@pytest .mark .anyio
261262async def test_run_background_tasks_even_if_client_disconnects () -> None :
262263 # test for https://github.com/encode/starlette/issues/1438
263- request_body_sent = False
264264 response_complete = anyio .Event ()
265265 background_task_run = anyio .Event ()
266266
@@ -293,13 +293,7 @@ async def passthrough(
293293 }
294294
295295 async def receive () -> Message :
296- nonlocal request_body_sent
297- if not request_body_sent :
298- request_body_sent = True
299- return {"type" : "http.request" , "body" : b"" , "more_body" : False }
300- # We simulate a client that disconnects immediately after receiving the response
301- await response_complete .wait ()
302- return {"type" : "http.disconnect" }
296+ raise NotImplementedError ("Should not be called!" ) # pragma: no cover
303297
304298 async def send (message : Message ) -> None :
305299 if message ["type" ] == "http.response.body" :
@@ -313,7 +307,6 @@ async def send(message: Message) -> None:
313307
314308@pytest .mark .anyio
315309async def test_do_not_block_on_background_tasks () -> None :
316- request_body_sent = False
317310 response_complete = anyio .Event ()
318311 events : list [str | Message ] = []
319312
@@ -345,12 +338,7 @@ async def passthrough(
345338 }
346339
347340 async def receive () -> Message :
348- nonlocal request_body_sent
349- if not request_body_sent :
350- request_body_sent = True
351- return {"type" : "http.request" , "body" : b"" , "more_body" : False }
352- await response_complete .wait ()
353- return {"type" : "http.disconnect" }
341+ raise NotImplementedError ("Should not be called!" ) # pragma: no cover
354342
355343 async def send (message : Message ) -> None :
356344 if message ["type" ] == "http.response.body" :
@@ -379,7 +367,6 @@ async def send(message: Message) -> None:
379367@pytest .mark .anyio
380368async def test_run_context_manager_exit_even_if_client_disconnects () -> None :
381369 # test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
382- request_body_sent = False
383370 response_complete = anyio .Event ()
384371 context_manager_exited = anyio .Event ()
385372
@@ -424,13 +411,7 @@ async def passthrough(
424411 }
425412
426413 async def receive () -> Message :
427- nonlocal request_body_sent
428- if not request_body_sent :
429- request_body_sent = True
430- return {"type" : "http.request" , "body" : b"" , "more_body" : False }
431- # We simulate a client that disconnects immediately after receiving the response
432- await response_complete .wait ()
433- return {"type" : "http.disconnect" }
414+ raise NotImplementedError ("Should not be called!" ) # pragma: no cover
434415
435416 async def send (message : Message ) -> None :
436417 if message ["type" ] == "http.response.body" :
@@ -778,7 +759,9 @@ async def rcv() -> AsyncGenerator[Message, None]:
778759 yield {"type" : "http.request" , "body" : b"1" , "more_body" : True }
779760 yield {"type" : "http.request" , "body" : b"2" , "more_body" : True }
780761 yield {"type" : "http.request" , "body" : b"3" }
781- await anyio .sleep (float ("inf" ))
762+ raise AssertionError ( # pragma: no cover
763+ "Should not be called, no need to poll for disconnect"
764+ )
782765
783766 sent : list [Message ] = []
784767
@@ -1033,3 +1016,139 @@ async def endpoint(request: Request) -> Response:
10331016 resp .raise_for_status ()
10341017
10351018 assert bodies == [b"Hello, World!-foo" ]
1019+
1020+
1021+ @pytest .mark .anyio
1022+ async def test_multiple_middlewares_stacked_client_disconnected () -> None :
1023+ class MyMiddleware (BaseHTTPMiddleware ):
1024+ def __init__ (self , app : ASGIApp , version : int , events : list [str ]) -> None :
1025+ self .version = version
1026+ self .events = events
1027+ super ().__init__ (app )
1028+
1029+ async def dispatch (
1030+ self , request : Request , call_next : RequestResponseEndpoint
1031+ ) -> Response :
1032+ self .events .append (f"{ self .version } :STARTED" )
1033+ res = await call_next (request )
1034+ self .events .append (f"{ self .version } :COMPLETED" )
1035+ return res
1036+
1037+ async def sleepy (request : Request ) -> Response :
1038+ try :
1039+ await request .body ()
1040+ except ClientDisconnect :
1041+ pass
1042+ else : # pragma: no cover
1043+ raise AssertionError ("Should have raised ClientDisconnect" )
1044+ return Response (b"" )
1045+
1046+ events : list [str ] = []
1047+
1048+ app = Starlette (
1049+ routes = [Route ("/" , sleepy )],
1050+ middleware = [
1051+ Middleware (MyMiddleware , version = _ + 1 , events = events ) for _ in range (10 )
1052+ ],
1053+ )
1054+
1055+ scope = {
1056+ "type" : "http" ,
1057+ "version" : "3" ,
1058+ "method" : "GET" ,
1059+ "path" : "/" ,
1060+ }
1061+
1062+ async def receive () -> AsyncIterator [Message ]:
1063+ yield {"type" : "http.disconnect" }
1064+
1065+ sent : list [Message ] = []
1066+
1067+ async def send (message : Message ) -> None :
1068+ sent .append (message )
1069+
1070+ await app (scope , receive ().__anext__ , send )
1071+
1072+ assert events == [
1073+ "1:STARTED" ,
1074+ "2:STARTED" ,
1075+ "3:STARTED" ,
1076+ "4:STARTED" ,
1077+ "5:STARTED" ,
1078+ "6:STARTED" ,
1079+ "7:STARTED" ,
1080+ "8:STARTED" ,
1081+ "9:STARTED" ,
1082+ "10:STARTED" ,
1083+ "10:COMPLETED" ,
1084+ "9:COMPLETED" ,
1085+ "8:COMPLETED" ,
1086+ "7:COMPLETED" ,
1087+ "6:COMPLETED" ,
1088+ "5:COMPLETED" ,
1089+ "4:COMPLETED" ,
1090+ "3:COMPLETED" ,
1091+ "2:COMPLETED" ,
1092+ "1:COMPLETED" ,
1093+ ]
1094+
1095+ assert sent == [
1096+ {
1097+ "type" : "http.response.start" ,
1098+ "status" : 200 ,
1099+ "headers" : [(b"content-length" , b"0" )],
1100+ },
1101+ {"type" : "http.response.body" , "body" : b"" , "more_body" : False },
1102+ ]
1103+
1104+
1105+ @pytest .mark .anyio
1106+ @pytest .mark .parametrize ("send_body" , [True , False ])
1107+ async def test_poll_for_disconnect_repeated (send_body : bool ) -> None :
1108+ async def app_poll_disconnect (scope : Scope , receive : Receive , send : Send ) -> None :
1109+ for _ in range (2 ):
1110+ msg = await receive ()
1111+ while msg ["type" ] == "http.request" :
1112+ msg = await receive ()
1113+ assert msg ["type" ] == "http.disconnect"
1114+ await Response (b"good!" )(scope , receive , send )
1115+
1116+ class MyMiddleware (BaseHTTPMiddleware ):
1117+ async def dispatch (
1118+ self , request : Request , call_next : RequestResponseEndpoint
1119+ ) -> Response :
1120+ return await call_next (request )
1121+
1122+ app = MyMiddleware (app_poll_disconnect )
1123+
1124+ scope = {
1125+ "type" : "http" ,
1126+ "version" : "3" ,
1127+ "method" : "GET" ,
1128+ "path" : "/" ,
1129+ }
1130+
1131+ async def receive () -> AsyncIterator [Message ]:
1132+ # the key here is that we only ever send 1 htt.disconnect message
1133+ if send_body :
1134+ yield {"type" : "http.request" , "body" : b"hello" , "more_body" : True }
1135+ yield {"type" : "http.request" , "body" : b"" , "more_body" : False }
1136+ yield {"type" : "http.disconnect" }
1137+ raise AssertionError ("Should not be called, would hang" ) # pragma: no cover
1138+
1139+ sent : list [Message ] = []
1140+
1141+ async def send (message : Message ) -> None :
1142+ sent .append (message )
1143+
1144+ await app (scope , receive ().__anext__ , send )
1145+
1146+ assert sent == [
1147+ {
1148+ "type" : "http.response.start" ,
1149+ "status" : 200 ,
1150+ "headers" : [(b"content-length" , b"5" )],
1151+ },
1152+ {"type" : "http.response.body" , "body" : b"good!" , "more_body" : True },
1153+ {"type" : "http.response.body" , "body" : b"" , "more_body" : False },
1154+ ]
0 commit comments