diff options
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r-- | Lib/ssl.py | 32 |
1 files changed, 27 insertions, 5 deletions
@@ -148,6 +148,7 @@ if sys.platform == "win32": from _ssl import enum_certificates, enum_crls from socket import getnameinfo as _getnameinfo +from socket import SHUT_RDWR as _SHUT_RDWR from socket import socket, AF_INET, SOCK_STREAM, create_connection import base64 # for DER-to-PEM translation import traceback @@ -235,7 +236,9 @@ def match_hostname(cert, hostname): returns nothing. """ if not cert: - raise ValueError("empty or no certificate") + raise ValueError("empty or no certificate, match_hostname needs a " + "SSL socket or SSL context with either " + "CERT_OPTIONAL or CERT_REQUIRED") dnsnames = [] san = cert.get('subjectAltName', ()) for key, value in san: @@ -387,9 +390,10 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, context.options |= getattr(_ssl, "OP_NO_COMPRESSION", 0) # disallow ciphers with known vulnerabilities context.set_ciphers(_RESTRICTED_CIPHERS) - # verify certs in client mode + # verify certs and host name in client mode if purpose == Purpose.SERVER_AUTH: context.verify_mode = CERT_REQUIRED + context.check_hostname = True if cafile or capath or cadata: context.load_verify_locations(cafile, capath, cadata) elif context.verify_mode != CERT_NONE: @@ -480,6 +484,13 @@ class SSLSocket(socket): if server_side and server_hostname: raise ValueError("server_hostname can only be specified " "in client mode") + if self._context.check_hostname and not server_hostname: + if HAS_SNI: + raise ValueError("check_hostname requires server_hostname") + else: + raise ValueError("check_hostname requires server_hostname, " + "but it's not supported by your OpenSSL " + "library") self.server_side = server_side self.server_hostname = server_hostname self.do_handshake_on_connect = do_handshake_on_connect @@ -522,9 +533,9 @@ class SSLSocket(socket): raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets") self.do_handshake() - except OSError as x: + except (OSError, ValueError): self.close() - raise x + raise @property def context(self): @@ -751,6 +762,17 @@ class SSLSocket(socket): finally: self.settimeout(timeout) + if self.context.check_hostname: + try: + if not self.server_hostname: + raise ValueError("check_hostname needs server_hostname " + "argument") + match_hostname(self.getpeercert(), self.server_hostname) + except Exception: + self.shutdown(_SHUT_RDWR) + self.close() + raise + def _real_connect(self, addr, connect_ex): if self.server_side: raise ValueError("can't connect in server-side mode") @@ -770,7 +792,7 @@ class SSLSocket(socket): if self.do_handshake_on_connect: self.do_handshake() return rc - except OSError: + except (OSError, ValueError): self._sslobj = None raise |