diff options
Diffstat (limited to 'Lib/ssl.py')
| -rw-r--r-- | Lib/ssl.py | 172 | 
1 files changed, 137 insertions, 35 deletions
| @@ -60,10 +60,25 @@ import re  import _ssl             # if we can't import it, let the error propagate  from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import _SSLContext, SSLError +from _ssl import _SSLContext +from _ssl import ( +    SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, +    SSLSyscallError, SSLEOFError, +    )  from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED -from _ssl import OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1 -from _ssl import RAND_status, RAND_egd, RAND_add +from _ssl import ( +    OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1, +    OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE +    ) +try: +    from _ssl import OP_NO_COMPRESSION +except ImportError: +    pass +try: +    from _ssl import OP_SINGLE_ECDH_USE +except ImportError: +    pass +from _ssl import RAND_status, RAND_egd, RAND_add, RAND_bytes, RAND_pseudo_bytes  from _ssl import (      SSL_ERROR_ZERO_RETURN,      SSL_ERROR_WANT_READ, @@ -75,8 +90,9 @@ from _ssl import (      SSL_ERROR_EOF,      SSL_ERROR_INVALID_ERROR_CODE,      ) -from _ssl import HAS_SNI -from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 +from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN +from _ssl import (PROTOCOL_SSLv3, PROTOCOL_SSLv23, +                  PROTOCOL_TLSv1)  from _ssl import _OPENSSL_API_VERSION  _PROTOCOL_NAMES = { @@ -94,11 +110,16 @@ else:  from socket import getnameinfo as _getnameinfo  from socket import error as socket_error -from socket import socket, AF_INET, SOCK_STREAM +from socket import socket, AF_INET, SOCK_STREAM, create_connection  import base64        # for DER-to-PEM translation  import traceback  import errno +if _ssl.HAS_TLS_UNIQUE: +    CHANNEL_BINDING_TYPES = ['tls-unique'] +else: +    CHANNEL_BINDING_TYPES = [] +  # Disable weak or insecure ciphers by default  # (OpenSSL's default setting is 'DEFAULT:!aNULL:!eNULL')  _DEFAULT_CIPHERS = 'DEFAULT:!aNULL:!eNULL:!LOW:!EXPORT:!SSLv2' @@ -108,31 +129,59 @@ class CertificateError(ValueError):      pass -def _dnsname_to_pat(dn, max_wildcards=1): +def _dnsname_match(dn, hostname, max_wildcards=1): +    """Matching according to RFC 6125, section 6.4.3 + +    http://tools.ietf.org/html/rfc6125#section-6.4.3 +    """      pats = [] -    for frag in dn.split(r'.'): -        if frag.count('*') > max_wildcards: -            # Issue #17980: avoid denials of service by refusing more -            # than one wildcard per fragment.  A survery of established -            # policy among SSL implementations showed it to be a -            # reasonable choice. -            raise CertificateError( -                "too many wildcards in certificate DNS name: " + repr(dn)) -        if frag == '*': -            # When '*' is a fragment by itself, it matches a non-empty dotless -            # fragment. -            pats.append('[^.]+') -        else: -            # Otherwise, '*' matches any dotless fragment. -            frag = re.escape(frag) -            pats.append(frag.replace(r'\*', '[^.]*')) -    return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) +    if not dn: +        return False + +    leftmost, *remainder = dn.split(r'.') + +    wildcards = leftmost.count('*') +    if wildcards > max_wildcards: +        # Issue #17980: avoid denials of service by refusing more +        # than one wildcard per fragment.  A survery of established +        # policy among SSL implementations showed it to be a +        # reasonable choice. +        raise CertificateError( +            "too many wildcards in certificate DNS name: " + repr(dn)) + +    # speed up common case w/o wildcards +    if not wildcards: +        return dn.lower() == hostname.lower() + +    # RFC 6125, section 6.4.3, subitem 1. +    # The client SHOULD NOT attempt to match a presented identifier in which +    # the wildcard character comprises a label other than the left-most label. +    if leftmost == '*': +        # When '*' is a fragment by itself, it matches a non-empty dotless +        # fragment. +        pats.append('[^.]+') +    elif leftmost.startswith('xn--') or hostname.startswith('xn--'): +        # RFC 6125, section 6.4.3, subitem 3. +        # The client SHOULD NOT attempt to match a presented identifier +        # where the wildcard character is embedded within an A-label or +        # U-label of an internationalized domain name. +        pats.append(re.escape(leftmost)) +    else: +        # Otherwise, '*' matches any dotless string, e.g. www* +        pats.append(re.escape(leftmost).replace(r'\*', '[^.]*')) + +    # add the remaining fragments, ignore any wildcards +    for frag in remainder: +        pats.append(re.escape(frag)) + +    pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) +    return pat.match(hostname)  def match_hostname(cert, hostname):      """Verify that *cert* (in decoded format as returned by -    SSLSocket.getpeercert()) matches the *hostname*.  RFC 2818 rules -    are mostly followed, but IP addresses are not accepted for *hostname*. +    SSLSocket.getpeercert()) matches the *hostname*.  RFC 2818 and RFC 6125 +    rules are followed, but IP addresses are not accepted for *hostname*.      CertificateError is raised on failure. On success, the function      returns nothing. @@ -143,7 +192,7 @@ def match_hostname(cert, hostname):      san = cert.get('subjectAltName', ())      for key, value in san:          if key == 'DNS': -            if _dnsname_to_pat(value).match(hostname): +            if _dnsname_match(value, hostname):                  return              dnsnames.append(value)      if not dnsnames: @@ -154,7 +203,7 @@ def match_hostname(cert, hostname):                  # XXX according to RFC 2818, the most specific Common Name                  # must be used.                  if key == 'commonName': -                    if _dnsname_to_pat(value).match(hostname): +                    if _dnsname_match(value, hostname):                          return                      dnsnames.append(value)      if len(dnsnames) > 1: @@ -195,6 +244,17 @@ class SSLContext(_SSLContext):                           server_hostname=server_hostname,                           _context=self) +    def set_npn_protocols(self, npn_protocols): +        protos = bytearray() +        for protocol in npn_protocols: +            b = bytes(protocol, 'ascii') +            if len(b) == 0 or len(b) > 255: +                raise SSLError('NPN protocols must be 1 to 255 in length') +            protos.append(len(b)) +            protos.extend(b) + +        self._set_npn_protocols(protos) +  class SSLSocket(socket):      """This class implements a subtype of socket.socket that wraps @@ -206,7 +266,7 @@ class SSLSocket(socket):                   ssl_version=PROTOCOL_SSLv23, ca_certs=None,                   do_handshake_on_connect=True,                   family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, -                 suppress_ragged_eofs=True, ciphers=None, +                 suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,                   server_hostname=None,                   _context=None): @@ -226,6 +286,8 @@ class SSLSocket(socket):                  self.context.load_verify_locations(ca_certs)              if certfile:                  self.context.load_cert_chain(certfile, keyfile) +            if npn_protocols: +                self.context.set_npn_protocols(npn_protocols)              if ciphers:                  self.context.set_ciphers(ciphers)              self.keyfile = keyfile @@ -326,6 +388,13 @@ class SSLSocket(socket):          self._checkClosed()          return self._sslobj.peer_certificate(binary_form) +    def selected_npn_protocol(self): +        self._checkClosed() +        if not self._sslobj or not _ssl.HAS_NPN: +            return None +        else: +            return self._sslobj.selected_npn_protocol() +      def cipher(self):          self._checkClosed()          if not self._sslobj: @@ -333,6 +402,13 @@ class SSLSocket(socket):          else:              return self._sslobj.cipher() +    def compression(self): +        self._checkClosed() +        if not self._sslobj: +            return None +        else: +            return self._sslobj.compression() +      def send(self, data, flags=0):          self._checkClosed()          if self._sslobj: @@ -365,6 +441,12 @@ class SSLSocket(socket):          else:              return socket.sendto(self, data, flags_or_addr, addr) +    def sendmsg(self, *args, **kwargs): +        # Ensure programs don't send data unencrypted if they try to +        # use this method. +        raise NotImplementedError("sendmsg not allowed on instances of %s" % +                                  self.__class__) +      def sendall(self, data, flags=0):          self._checkClosed()          if self._sslobj: @@ -423,6 +505,14 @@ class SSLSocket(socket):          else:              return socket.recvfrom_into(self, buffer, nbytes, flags) +    def recvmsg(self, *args, **kwargs): +        raise NotImplementedError("recvmsg not allowed on instances of %s" % +                                  self.__class__) + +    def recvmsg_into(self, *args, **kwargs): +        raise NotImplementedError("recvmsg_into not allowed on instances of " +                                  "%s" % self.__class__) +      def pending(self):          self._checkClosed()          if self._sslobj: @@ -504,16 +594,28 @@ class SSLSocket(socket):                      server_side=True)          return newsock, addr -    def __del__(self): -        # sys.stderr.write("__del__ on %s\n" % repr(self)) -        self._real_close() +    def get_channel_binding(self, cb_type="tls-unique"): +        """Get channel binding data for current connection.  Raise ValueError +        if the requested `cb_type` is not supported.  Return bytes of the data +        or None if the data is not available (e.g. before the handshake). +        """ +        if cb_type not in CHANNEL_BINDING_TYPES: +            raise ValueError("Unsupported channel binding type") +        if cb_type != "tls-unique": +            raise NotImplementedError( +                            "{0} channel binding type not implemented" +                            .format(cb_type)) +        if self._sslobj is None: +            return None +        return self._sslobj.tls_unique_cb()  def wrap_socket(sock, keyfile=None, certfile=None,                  server_side=False, cert_reqs=CERT_NONE,                  ssl_version=PROTOCOL_SSLv23, ca_certs=None,                  do_handshake_on_connect=True, -                suppress_ragged_eofs=True, ciphers=None): +                suppress_ragged_eofs=True, +                ciphers=None):      return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,                       server_side=server_side, cert_reqs=cert_reqs, @@ -568,9 +670,9 @@ def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):          cert_reqs = CERT_REQUIRED      else:          cert_reqs = CERT_NONE -    s = wrap_socket(socket(), ssl_version=ssl_version, +    s = create_connection(addr) +    s = wrap_socket(s, ssl_version=ssl_version,                      cert_reqs=cert_reqs, ca_certs=ca_certs) -    s.connect(addr)      dercert = s.getpeercert(True)      s.close()      return DER_cert_to_PEM_cert(dercert) | 
