diff --git a/aiohttp/websocket_client.py b/aiohttp/websocket_client.py index 925614ff45f..706cc50a5fd 100644 --- a/aiohttp/websocket_client.py +++ b/aiohttp/websocket_client.py @@ -35,7 +35,7 @@ class MsgType(IntEnum): @asyncio.coroutine def ws_connect(url, protocols=(), timeout=10.0, connector=None, - autoclose=True, autoping=True, loop=None): + response_class=None, autoclose=True, autoping=True, loop=None): """Initiate websocket connection.""" if loop is None: loop = asyncio.get_event_loop() @@ -87,7 +87,10 @@ def ws_connect(url, protocols=(), timeout=10.0, connector=None, reader = resp.connection.reader.set_parser(WebSocketParser) writer = WebSocketWriter(resp.connection.writer, use_mask=True) - return ClientWebSocketResponse( + if response_class is None: + response_class = ClientWebSocketResponse + + return response_class( reader, writer, protocol, resp, timeout, autoclose, autoping, loop) diff --git a/docs/client_websockets.rst b/docs/client_websockets.rst index 53a5be774fa..9bd6c4e21ab 100644 --- a/docs/client_websockets.rst +++ b/docs/client_websockets.rst @@ -46,7 +46,7 @@ ClientWebSocketResponse To connect to a websocket server you have to use the `aiohttp.ws_connect()` function, do not create an instance of class :class:`ClientWebSocketResponse` manually. -.. py:function:: ws_connect(url, protocols=(), connector=None, autoclose=True, autoping=True, loop=None) +.. py:function:: ws_connect(url, protocols=(), connector=None, response_class=None, autoclose=True, autoping=True, loop=None) This function creates a websocket connection, checks the response and returns a :class:`ClientWebSocketResponse` object. In case of failure @@ -58,6 +58,8 @@ do not create an instance of class :class:`ClientWebSocketResponse` manually. :param obj connector: object :class:`TCPConnector` + :param response_class: (optional) Custom Response class implementation. + :param bool autoclose: automatically close websocket connection on close message from server. if `autoclose` is False them close procedure has to be handled manually diff --git a/tests/test_websocket_client.py b/tests/test_websocket_client.py index 08ccf98533f..38ec9493cf8 100644 --- a/tests/test_websocket_client.py +++ b/tests/test_websocket_client.py @@ -48,6 +48,33 @@ def test_ws_connect(self, m_client, m_os): self.assertIsInstance(res, websocket_client.ClientWebSocketResponse) self.assertEqual(res.protocol, 'chat') + @mock.patch('aiohttp.websocket_client.os') + @mock.patch('aiohttp.websocket_client.client') + def test_ws_connect_custom_response(self, m_client, m_os): + + class CustomResponse(websocket_client.ClientWebSocketResponse): + def read(self, decode=False): + return 'customized!' + + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: hdrs.WEBSOCKET, + hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.SEC_WEBSOCKET_ACCEPT: self.ws_key, + } + m_os.urandom.return_value = self.key_data + m_client.request.return_value = asyncio.Future(loop=self.loop) + m_client.request.return_value.set_result(resp) + + res = self.loop.run_until_complete( + websocket_client.ws_connect( + 'http://test.org', + response_class=CustomResponse, + loop=self.loop)) + + self.assertEqual(res.read(), 'customized!') + @mock.patch('aiohttp.websocket_client.os') @mock.patch('aiohttp.websocket_client.client') def test_ws_connect_global_loop(self, m_client, m_os):