diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/ssl.py | 32 | ||||
-rw-r--r-- | Lib/test/test_ssl.py | 62 |
2 files changed, 89 insertions, 5 deletions
@@ -148,6 +148,7 @@ if sys.platform == "win32": from _ssl import enum_certificates, enum_crls from socket import getnameinfo as _getnameinfo +from socket import SHUT_RDWR as _SHUT_RDWR from socket import socket, AF_INET, SOCK_STREAM, create_connection import base64 # for DER-to-PEM translation import traceback @@ -235,7 +236,9 @@ def match_hostname(cert, hostname): returns nothing. """ if not cert: - raise ValueError("empty or no certificate") + raise ValueError("empty or no certificate, match_hostname needs a " + "SSL socket or SSL context with either " + "CERT_OPTIONAL or CERT_REQUIRED") dnsnames = [] san = cert.get('subjectAltName', ()) for key, value in san: @@ -387,9 +390,10 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, context.options |= getattr(_ssl, "OP_NO_COMPRESSION", 0) # disallow ciphers with known vulnerabilities context.set_ciphers(_RESTRICTED_CIPHERS) - # verify certs in client mode + # verify certs and host name in client mode if purpose == Purpose.SERVER_AUTH: context.verify_mode = CERT_REQUIRED + context.check_hostname = True if cafile or capath or cadata: context.load_verify_locations(cafile, capath, cadata) elif context.verify_mode != CERT_NONE: @@ -480,6 +484,13 @@ class SSLSocket(socket): if server_side and server_hostname: raise ValueError("server_hostname can only be specified " "in client mode") + if self._context.check_hostname and not server_hostname: + if HAS_SNI: + raise ValueError("check_hostname requires server_hostname") + else: + raise ValueError("check_hostname requires server_hostname, " + "but it's not supported by your OpenSSL " + "library") self.server_side = server_side self.server_hostname = server_hostname self.do_handshake_on_connect = do_handshake_on_connect @@ -522,9 +533,9 @@ class SSLSocket(socket): raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets") self.do_handshake() - except OSError as x: + except (OSError, ValueError): self.close() - raise x + raise @property def context(self): @@ -751,6 +762,17 @@ class SSLSocket(socket): finally: self.settimeout(timeout) + if self.context.check_hostname: + try: + if not self.server_hostname: + raise ValueError("check_hostname needs server_hostname " + "argument") + match_hostname(self.getpeercert(), self.server_hostname) + except Exception: + self.shutdown(_SHUT_RDWR) + self.close() + raise + def _real_connect(self, addr, connect_ex): if self.server_side: raise ValueError("can't connect in server-side mode") @@ -770,7 +792,7 @@ class SSLSocket(socket): if self.do_handshake_on_connect: self.do_handshake() return rc - except OSError: + except (OSError, ValueError): self._sslobj = None raise diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index afec72a..ed263c3 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -1003,6 +1003,7 @@ class ContextTests(unittest.TestCase): ctx = ssl.create_default_context() self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self.assertTrue(ctx.check_hostname) self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) with open(SIGNING_CA) as f: @@ -1022,6 +1023,7 @@ class ContextTests(unittest.TestCase): ctx = ssl._create_stdlib_context() self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23) self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + self.assertFalse(ctx.check_hostname) self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1) @@ -1040,6 +1042,28 @@ class ContextTests(unittest.TestCase): self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + def test_check_hostname(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertFalse(ctx.check_hostname) + + # Requires CERT_REQUIRED or CERT_OPTIONAL + with self.assertRaises(ValueError): + ctx.check_hostname = True + ctx.verify_mode = ssl.CERT_REQUIRED + self.assertFalse(ctx.check_hostname) + ctx.check_hostname = True + self.assertTrue(ctx.check_hostname) + + ctx.verify_mode = ssl.CERT_OPTIONAL + ctx.check_hostname = True + self.assertTrue(ctx.check_hostname) + + # Cannot set CERT_NONE with check_hostname enabled + with self.assertRaises(ValueError): + ctx.verify_mode = ssl.CERT_NONE + ctx.check_hostname = False + self.assertFalse(ctx.check_hostname) + class SSLErrorTests(unittest.TestCase): @@ -1930,6 +1954,44 @@ else: cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") + def test_check_hostname(self): + if support.verbose: + sys.stdout.write("\n") + + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + context.load_verify_locations(SIGNING_CA) + + # correct hostname should verify + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket(), + server_hostname="localhost") as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + + # incorrect hostname should raise an exception + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket(), + server_hostname="invalid") as s: + with self.assertRaisesRegex(ssl.CertificateError, + "hostname 'invalid' doesn't match 'localhost'"): + s.connect((HOST, server.port)) + + # missing server_hostname arg should cause an exception, too + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with socket.socket() as s: + with self.assertRaisesRegex(ValueError, + "check_hostname requires server_hostname"): + context.wrap_socket(s) + def test_empty_cert(self): """Connecting with an empty cert file""" bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, |