Skip to content

Commit 7041757

Browse files
committed
support new TCPConnector param fingerprint
enables ssl certificate pinning
1 parent fc7cbbf commit 7041757

File tree

7 files changed

+136
-11
lines changed

7 files changed

+136
-11
lines changed

CHANGES.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ CHANGES
44
0.16.0 (XX-XX-XXXX)
55
-------------------
66

7+
- Support new `fingerprint` param of TCPConnector to enable verifying
8+
ssl certificates via md5, sha1, or sha256 digest
9+
710
- Setup uploaded filename if field value is binary and transfer
811
encoding is not specified #349
912

CONTRIBUTORS.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ Olaf Conradi
3838
Paul Colomiets
3939
Philipp A.
4040
Raúl Cumplido
41+
"Required Field" <requiredfield256@gmail.com>
4142
Robert Lu
4243
Sebastian Hanula
4344
Simon Kennedy

aiohttp/connector.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import warnings
1010

1111
from collections import defaultdict
12+
from hashlib import md5, sha1, sha256
1213
from itertools import chain
1314
from math import ceil
1415

@@ -17,6 +18,7 @@
1718
from .errors import ServerDisconnectedError
1819
from .errors import HttpProxyError, ProxyConnectionError
1920
from .errors import ClientOSError, ClientTimeoutError
21+
from .errors import FingerprintMismatch
2022
from .helpers import BasicAuth
2123

2224

@@ -25,6 +27,12 @@
2527
PY_34 = sys.version_info >= (3, 4)
2628
PY_343 = sys.version_info >= (3, 4, 3)
2729

30+
HASHFUNC_BY_DIGESTLEN = {
31+
16: md5,
32+
20: sha1,
33+
32: sha256,
34+
}
35+
2836

2937
class Connection(object):
3038

@@ -347,13 +355,17 @@ class TCPConnector(BaseConnector):
347355
"""TCP connector.
348356
349357
:param bool verify_ssl: Set to True to check ssl certifications.
358+
:param bytes fingerprint: Pass the binary md5, sha1, or sha256
359+
digest of the expected certificate in DER format to verify
360+
the cert the server presents matches. See also
361+
https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning
350362
:param bool resolve: Set to True to do DNS lookup for host name.
351363
:param family: socket address family
352364
:param args: see :class:`BaseConnector`
353365
:param kwargs: see :class:`BaseConnector`
354366
"""
355367

356-
def __init__(self, *, verify_ssl=True,
368+
def __init__(self, *, verify_ssl=True, fingerprint=None,
357369
resolve=False, family=socket.AF_INET, ssl_context=None,
358370
**kwargs):
359371
super().__init__(**kwargs)
@@ -364,6 +376,15 @@ def __init__(self, *, verify_ssl=True,
364376
"verify_ssl=False or specify ssl_context, not both.")
365377

366378
self._verify_ssl = verify_ssl
379+
380+
if fingerprint:
381+
digestlen = len(fingerprint)
382+
hashfunc = HASHFUNC_BY_DIGESTLEN.get(digestlen)
383+
if not hashfunc:
384+
raise ValueError('fingerprint has invalid length')
385+
self._hashfunc = hashfunc
386+
self._fingerprint = fingerprint
387+
367388
self._ssl_context = ssl_context
368389
self._family = family
369390
self._resolve = resolve
@@ -374,6 +395,11 @@ def verify_ssl(self):
374395
"""Do check for ssl certifications?"""
375396
return self._verify_ssl
376397

398+
@property
399+
def fingerprint(self):
400+
"""Expected ssl certificate fingerprint."""
401+
return self._fingerprint
402+
377403
@property
378404
def ssl_context(self):
379405
"""SSLContext instance for https requests.
@@ -464,11 +490,25 @@ def _create_connection(self, req):
464490

465491
for hinfo in hosts:
466492
try:
467-
return (yield from self._loop.create_connection(
468-
self._factory, hinfo['host'], hinfo['port'],
493+
host = hinfo['host']
494+
port = hinfo['port']
495+
conn = yield from self._loop.create_connection(
496+
self._factory, host, port,
469497
ssl=sslcontext, family=hinfo['family'],
470498
proto=hinfo['proto'], flags=hinfo['flags'],
471-
server_hostname=hinfo['hostname'] if sslcontext else None))
499+
server_hostname=hinfo['hostname'] if sslcontext else None)
500+
transport = conn[0]
501+
has_cert = transport.get_extra_info('sslcontext')
502+
if has_cert and self._fingerprint:
503+
sock = transport.get_extra_info('socket')
504+
# gives DER-encoded cert as a sequence of bytes (or None)
505+
cert = sock.getpeercert(binary_form=True)
506+
assert cert
507+
got = self._hashfunc(cert).digest()
508+
expected = self._fingerprint
509+
if got != expected:
510+
raise FingerprintMismatch(expected, got, host, port)
511+
return conn
472512
except OSError as e:
473513
exc = e
474514
else:

aiohttp/errors.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
'ClientError', 'ClientHttpProcessingError', 'ClientConnectionError',
1414
'ClientOSError', 'ClientTimeoutError', 'ProxyConnectionError',
1515
'ClientRequestError', 'ClientResponseError',
16+
'FingerprintMismatch',
1617

1718
'WSServerHandshakeError', 'WSClientDisconnectedError')
1819

@@ -170,3 +171,18 @@ class LineLimitExceededParserError(ParserError):
170171
def __init__(self, msg, limit):
171172
super().__init__(msg)
172173
self.limit = limit
174+
175+
176+
class FingerprintMismatch(ClientConnectionError):
177+
"""SSL certificate does not match expected fingerprint."""
178+
179+
def __init__(self, expected, got, host, port):
180+
self.expected = expected
181+
self.got = got
182+
self.host = host
183+
self.port = port
184+
185+
def __repr__(self):
186+
return '<{} expected={} got={} host={} port={}>'.format(
187+
self.__class__.__name__, self.expected, self.got,
188+
self.host, self.port)

docs/client.rst

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -396,20 +396,37 @@ By default it uses strict checks for HTTPS protocol. Certification
396396
checks can be relaxed by passing ``verify_ssl=False``::
397397

398398
>>> conn = aiohttp.TCPConnector(verify_ssl=False)
399-
>>> r = yield from aiohttp.request(
400-
... 'get', 'https://example.com', connector=conn)
399+
>>> session = aiohttp.ClientSession(connector=conn)
400+
>>> r = yield from session.get('https://example.com')
401401

402402

403403
If you need to setup custom ssl parameters (use own certification
404404
files for example) you can create a :class:`ssl.SSLContext` instance and
405405
pass it into the connector::
406406

407-
>>> sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
408-
>>> sslcontext.verify_mode = ssl.CERT_REQUIRED
409-
>>> sslcontext.load_verify_locations("/etc/ssl/certs/ca-bundle.crt")
407+
>>> sslcontext = ssl.create_default_context(cafile='/path/to/ca-bundle.crt')
410408
>>> conn = aiohttp.TCPConnector(ssl_context=sslcontext)
411-
>>> r = yield from aiohttp.request(
412-
... 'get', 'https://example.com', connector=conn)
409+
>>> session = aiohttp.ClientSession(connector=conn)
410+
>>> r = yield from session.get('https://example.com')
411+
412+
You may also verify certificates via fingerprint::
413+
414+
>>> # Attempt to connect to https://www.python.org
415+
>>> # with a pin to a bogus certificate:
416+
>>> bad_md5 = b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06='
417+
>>> conn = aiohttp.TCPConnector(fingerprint=bad_md5)
418+
>>> session = aiohttp.ClientSession(connector=conn)
419+
>>> exc = None
420+
>>> try:
421+
... r = yield from session.get('https://www.python.org')
422+
... except FingerprintMismatch as e:
423+
... exc = e
424+
>>> exc is not None
425+
True
426+
>>> exc.expected == bad_md5
427+
True
428+
>>> exc.got # www.python.org cert's actual md5
429+
b'\xca;I\x9cuv\x8es\x138N$?\x15\xca\xcb'
413430

414431

415432
Unix domain sockets

tests/sample.crt.der

567 Bytes
Binary file not shown.

tests/test_connector.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import aiohttp
1313
from aiohttp import client
1414
from aiohttp import test_utils
15+
from aiohttp.errors import FingerprintMismatch
1516
from aiohttp.client import ClientResponse, ClientRequest
1617
from aiohttp.connector import Connection
1718

@@ -452,10 +453,57 @@ def test_cleanup3(self):
452453
def test_tcp_connector_ctor(self):
453454
conn = aiohttp.TCPConnector(loop=self.loop)
454455
self.assertTrue(conn.verify_ssl)
456+
self.assertIs(conn.fingerprint, None)
455457
self.assertFalse(conn.resolve)
456458
self.assertEqual(conn.family, socket.AF_INET)
457459
self.assertEqual(conn.resolved_hosts, {})
458460

461+
def test_tcp_connector_ctor_fingerprint_valid(self):
462+
valid = b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06='
463+
conn = aiohttp.TCPConnector(loop=self.loop, fingerprint=valid)
464+
self.assertEqual(conn.fingerprint, valid)
465+
466+
def test_tcp_connector_fingerprint_invalid(self):
467+
invalid = b'\x00'
468+
with self.assertRaises(ValueError):
469+
aiohttp.TCPConnector(loop=self.loop, fingerprint=invalid)
470+
471+
def test_tcp_connector_fingerprint(self):
472+
# The even-index fingerprints below are "expect success" cases
473+
# for ./sample.crt.der, the cert presented by test_utils.run_server.
474+
# The odd-index fingerprints are "expect fail" cases.
475+
testcases = (
476+
# md5
477+
b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06=',
478+
b'\x00' * 16,
479+
480+
# sha1
481+
b's\x93\xfd:\xed\x08\x1do\xa9\xaeq9\x1a\xe3\xc5\x7f\x89\xe7l\xf9',
482+
b'\x00' * 20,
483+
484+
# sha256
485+
b'0\x9a\xc9D\x83\xdc\x91\'\x88\x91\x11\xa1d\x97\xfd\xcb~7U\x14D@L'
486+
b'\x11\xab\x99\xa8\xae\xb7\x14\xee\x8b',
487+
b'\x00' * 32,
488+
)
489+
for i, fingerprint in enumerate(testcases):
490+
expect_fail = i % 2
491+
conn = aiohttp.TCPConnector(loop=self.loop, verify_ssl=False,
492+
fingerprint=fingerprint)
493+
with test_utils.run_server(self.loop, use_ssl=True) as httpd:
494+
coro = client.request('get', httpd.url('method', 'get'),
495+
connector=conn, loop=self.loop)
496+
if expect_fail:
497+
with self.assertRaises(FingerprintMismatch) as cm:
498+
self.loop.run_until_complete(coro)
499+
exc = cm.exception
500+
self.assertEqual(exc.expected, fingerprint)
501+
# the previous test case should be what we actually got
502+
self.assertEqual(exc.got, testcases[i-1])
503+
else:
504+
# should not raise
505+
self.loop.run_until_complete(coro)
506+
459507
def test_tcp_connector_clear_resolved_hosts(self):
460508
conn = aiohttp.TCPConnector(loop=self.loop)
461509
info = object()

0 commit comments

Comments
 (0)