Skip to content

Commit ec6cfae

Browse files
authored
Implement base protocol class (#2986)
1 parent e0378bd commit ec6cfae

File tree

5 files changed

+246
-16
lines changed

5 files changed

+246
-16
lines changed

CHANGES/2986.feature

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Implement base protocol class to avoid a dependency from internal
2+
``asyncio.streams.FlowControlMixin``

aiohttp/base_protocol.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import asyncio
2+
3+
from .log import internal_logger
4+
5+
6+
class BaseProtocol(asyncio.Protocol):
7+
def __init__(self, loop=None):
8+
if loop is None:
9+
self._loop = asyncio.get_event_loop()
10+
else:
11+
self._loop = loop
12+
self._paused = False
13+
self._drain_waiter = None
14+
self._connection_lost = False
15+
self.transport = None
16+
17+
def pause_writing(self):
18+
assert not self._paused
19+
self._paused = True
20+
if self._loop.get_debug():
21+
internal_logger.debug("%r pauses writing", self)
22+
23+
def resume_writing(self):
24+
assert self._paused
25+
self._paused = False
26+
if self._loop.get_debug():
27+
internal_logger.debug("%r resumes writing", self)
28+
29+
waiter = self._drain_waiter
30+
if waiter is not None:
31+
self._drain_waiter = None
32+
if not waiter.done():
33+
waiter.set_result(None)
34+
35+
def connection_made(self, transport):
36+
self.transport = transport
37+
38+
def connection_lost(self, exc):
39+
self._connection_lost = True
40+
# Wake up the writer if currently paused.
41+
self.transport = None
42+
if not self._paused:
43+
return
44+
waiter = self._drain_waiter
45+
if waiter is None:
46+
return
47+
self._drain_waiter = None
48+
if waiter.done():
49+
return
50+
if exc is None:
51+
waiter.set_result(None)
52+
else:
53+
waiter.set_exception(exc)
54+
55+
async def _drain_helper(self):
56+
if self._connection_lost:
57+
raise ConnectionResetError('Connection lost')
58+
if not self._paused:
59+
return
60+
waiter = self._drain_waiter
61+
assert waiter is None or waiter.cancelled()
62+
waiter = self._loop.create_future()
63+
self._drain_waiter = waiter
64+
await waiter

aiohttp/client_proto.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
1-
import asyncio
2-
import asyncio.streams
31
from contextlib import suppress
42

3+
from .base_protocol import BaseProtocol
54
from .client_exceptions import (ClientOSError, ClientPayloadError,
65
ServerDisconnectedError)
76
from .http import HttpResponseParser
87
from .streams import EMPTY_PAYLOAD, DataQueue
98

109

11-
class ResponseHandler(DataQueue, asyncio.streams.FlowControlMixin):
10+
class ResponseHandler(BaseProtocol, DataQueue):
1211
"""Helper class to adapt between Protocol and StreamReader."""
1312

1413
def __init__(self, *, loop=None):
15-
asyncio.streams.FlowControlMixin.__init__(self, loop=loop)
14+
BaseProtocol.__init__(self, loop=loop)
1615
DataQueue.__init__(self, loop=loop)
1716

18-
self.transport = None
1917
self._should_close = False
2018

2119
self._message = None
@@ -56,9 +54,6 @@ def close(self):
5654
def is_connected(self):
5755
return self.transport is not None
5856

59-
def connection_made(self, transport):
60-
self.transport = transport
61-
6257
def connection_lost(self, exc):
6358
if self._payload_parser is not None:
6459
with suppress(Exception):
@@ -81,7 +76,6 @@ def connection_lost(self, exc):
8176
# we do it anyway below
8277
self.set_exception(exc)
8378

84-
self.transport = None
8579
self._should_close = True
8680
self._parser = None
8781
self._message = None

aiohttp/web_protocol.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import yarl
1111

1212
from . import helpers, http
13+
from .base_protocol import BaseProtocol
1314
from .helpers import CeilTimeout
1415
from .http import HttpProcessingError, HttpRequestParser, StreamWriter
1516
from .log import access_logger, server_logger
@@ -35,7 +36,7 @@ class PayloadAccessError(Exception):
3536
"""Payload was accesed after responce was sent."""
3637

3738

38-
class RequestHandler(asyncio.streams.FlowControlMixin, asyncio.Protocol):
39+
class RequestHandler(BaseProtocol):
3940
"""HTTP protocol implementation.
4041
4142
RequestHandler handles incoming HTTP request. It reads request line,
@@ -93,8 +94,6 @@ def __init__(self, manager, *, loop=None,
9394

9495
super().__init__(loop=loop)
9596

96-
self._loop = loop if loop is not None else asyncio.get_event_loop()
97-
9897
self._manager = manager
9998
self._request_handler = manager.request_handler
10099
self._request_factory = manager.request_factory
@@ -121,7 +120,6 @@ def __init__(self, manager, *, loop=None,
121120
max_headers=max_headers,
122121
payload_exception=RequestPayloadError)
123122

124-
self.transport = None
125123
self._reading_paused = False
126124

127125
self.logger = logger
@@ -177,8 +175,6 @@ async def shutdown(self, timeout=15.0):
177175
def connection_made(self, transport):
178176
super().connection_made(transport)
179177

180-
self.transport = transport
181-
182178
if self._tcp_keepalive:
183179
tcp_keepalive(transport)
184180

@@ -196,7 +192,6 @@ def connection_lost(self, exc):
196192
self._request_factory = None
197193
self._request_handler = None
198194
self._request_parser = None
199-
self.transport = None
200195

201196
if self._keepalive_handle is not None:
202197
self._keepalive_handle.cancel()

tests/test_base_protocol.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import asyncio
2+
from contextlib import suppress
3+
from unittest import mock
4+
5+
import pytest
6+
7+
from aiohttp.base_protocol import BaseProtocol
8+
9+
10+
def test_loop(loop):
11+
asyncio.set_event_loop(None)
12+
pr = BaseProtocol(loop=loop)
13+
assert pr._loop is loop
14+
15+
16+
def test_default_loop(loop):
17+
asyncio.set_event_loop(loop)
18+
pr = BaseProtocol()
19+
assert pr._loop is loop
20+
21+
22+
def test_pause_writing(loop):
23+
pr = BaseProtocol(loop=loop)
24+
assert not pr._paused
25+
pr.pause_writing()
26+
assert pr._paused
27+
28+
29+
def test_resume_writing_no_waiters(loop):
30+
pr = BaseProtocol(loop=loop)
31+
pr.pause_writing()
32+
assert pr._paused
33+
pr.resume_writing()
34+
assert not pr._paused
35+
36+
37+
def test_connection_made(loop):
38+
pr = BaseProtocol(loop=loop)
39+
tr = mock.Mock()
40+
assert pr.transport is None
41+
pr.connection_made(tr)
42+
assert pr.transport is not None
43+
44+
45+
def test_connection_lost_not_paused(loop):
46+
pr = BaseProtocol(loop=loop)
47+
tr = mock.Mock()
48+
pr.connection_made(tr)
49+
assert not pr._connection_lost
50+
pr.connection_lost(None)
51+
assert pr.transport is None
52+
assert pr._connection_lost
53+
54+
55+
def test_connection_lost_paused_without_waiter(loop):
56+
pr = BaseProtocol(loop=loop)
57+
tr = mock.Mock()
58+
pr.connection_made(tr)
59+
assert not pr._connection_lost
60+
pr.pause_writing()
61+
pr.connection_lost(None)
62+
assert pr.transport is None
63+
assert pr._connection_lost
64+
65+
66+
async def test_drain_lost(loop):
67+
pr = BaseProtocol(loop=loop)
68+
tr = mock.Mock()
69+
pr.connection_made(tr)
70+
pr.connection_lost(None)
71+
with pytest.raises(ConnectionResetError):
72+
await pr._drain_helper()
73+
74+
75+
async def test_drain_not_paused(loop):
76+
pr = BaseProtocol(loop=loop)
77+
tr = mock.Mock()
78+
pr.connection_made(tr)
79+
assert pr._drain_waiter is None
80+
await pr._drain_helper()
81+
assert pr._drain_waiter is None
82+
83+
84+
async def test_resume_drain_waited(loop):
85+
pr = BaseProtocol(loop=loop)
86+
tr = mock.Mock()
87+
pr.connection_made(tr)
88+
pr.pause_writing()
89+
90+
t = loop.create_task(pr._drain_helper())
91+
await asyncio.sleep(0)
92+
93+
assert pr._drain_waiter is not None
94+
pr.resume_writing()
95+
assert (await t) is None
96+
assert pr._drain_waiter is None
97+
98+
99+
async def test_lost_drain_waited_ok(loop):
100+
pr = BaseProtocol(loop=loop)
101+
tr = mock.Mock()
102+
pr.connection_made(tr)
103+
pr.pause_writing()
104+
105+
t = loop.create_task(pr._drain_helper())
106+
await asyncio.sleep(0)
107+
108+
assert pr._drain_waiter is not None
109+
pr.connection_lost(None)
110+
assert (await t) is None
111+
assert pr._drain_waiter is None
112+
113+
114+
async def test_lost_drain_waited_exception(loop):
115+
pr = BaseProtocol(loop=loop)
116+
tr = mock.Mock()
117+
pr.connection_made(tr)
118+
pr.pause_writing()
119+
120+
t = loop.create_task(pr._drain_helper())
121+
await asyncio.sleep(0)
122+
123+
assert pr._drain_waiter is not None
124+
exc = RuntimeError()
125+
pr.connection_lost(exc)
126+
with pytest.raises(RuntimeError) as cm:
127+
await t
128+
assert cm.value is exc
129+
assert pr._drain_waiter is None
130+
131+
132+
async def test_lost_drain_cancelled(loop):
133+
pr = BaseProtocol(loop=loop)
134+
tr = mock.Mock()
135+
pr.connection_made(tr)
136+
pr.pause_writing()
137+
138+
fut = loop.create_future()
139+
140+
async def wait():
141+
fut.set_result(None)
142+
await pr._drain_helper()
143+
144+
t = loop.create_task(wait())
145+
await fut
146+
t.cancel()
147+
148+
assert pr._drain_waiter is not None
149+
pr.connection_lost(None)
150+
with suppress(asyncio.CancelledError):
151+
await t
152+
assert pr._drain_waiter is None
153+
154+
155+
async def test_resume_drain_cancelled(loop):
156+
pr = BaseProtocol(loop=loop)
157+
tr = mock.Mock()
158+
pr.connection_made(tr)
159+
pr.pause_writing()
160+
161+
fut = loop.create_future()
162+
163+
async def wait():
164+
fut.set_result(None)
165+
await pr._drain_helper()
166+
167+
t = loop.create_task(wait())
168+
await fut
169+
t.cancel()
170+
171+
assert pr._drain_waiter is not None
172+
pr.resume_writing()
173+
with suppress(asyncio.CancelledError):
174+
await t
175+
assert pr._drain_waiter is None

0 commit comments

Comments
 (0)