diff options
Diffstat (limited to 'Lib/ssl.py')
| -rw-r--r-- | Lib/ssl.py | 192 | 
1 files changed, 158 insertions, 34 deletions
| @@ -55,13 +55,16 @@ PROTOCOL_TLSv1  """  import textwrap +import re  import _ssl             # if we can't import it, let the error propagate -from _ssl import SSLError +from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION +from _ssl import _SSLContext, SSLError  from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED  from _ssl import (PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23,                    PROTOCOL_TLSv1) +from _ssl import OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1  from _ssl import RAND_status, RAND_egd, RAND_add  from _ssl import (      SSL_ERROR_ZERO_RETURN, @@ -74,17 +77,98 @@ from _ssl import (      SSL_ERROR_EOF,      SSL_ERROR_INVALID_ERROR_CODE,      ) +from _ssl import HAS_SNI  from socket import getnameinfo as _getnameinfo  from socket import error as socket_error -from socket import dup as _dup  from socket import socket, AF_INET, SOCK_STREAM  import base64        # for DER-to-PEM translation  import traceback  import errno -class SSLSocket(socket): +class CertificateError(ValueError): +    pass + + +def _dnsname_to_pat(dn): +    pats = [] +    for frag in dn.split(r'.'): +        if frag == '*': +            # When '*' is a fragment by itself, it matches a non-empty dotless +            # fragment. +            pats.append('[^.]+') +        else: +            # Otherwise, '*' matches any dotless fragment. +            frag = re.escape(frag) +            pats.append(frag.replace(r'\*', '[^.]*')) +    return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) + + +def match_hostname(cert, hostname): +    """Verify that *cert* (in decoded format as returned by +    SSLSocket.getpeercert()) matches the *hostname*.  RFC 2818 rules +    are mostly followed, but IP addresses are not accepted for *hostname*. + +    CertificateError is raised on failure. On success, the function +    returns nothing. +    """ +    if not cert: +        raise ValueError("empty or no certificate") +    dnsnames = [] +    san = cert.get('subjectAltName', ()) +    for key, value in san: +        if key == 'DNS': +            if _dnsname_to_pat(value).match(hostname): +                return +            dnsnames.append(value) +    if not san: +        # The subject is only checked when subjectAltName is empty +        for sub in cert.get('subject', ()): +            for key, value in sub: +                # XXX according to RFC 2818, the most specific Common Name +                # must be used. +                if key == 'commonName': +                    if _dnsname_to_pat(value).match(hostname): +                        return +                    dnsnames.append(value) +    if len(dnsnames) > 1: +        raise CertificateError("hostname %r " +            "doesn't match either of %s" +            % (hostname, ', '.join(map(repr, dnsnames)))) +    elif len(dnsnames) == 1: +        raise CertificateError("hostname %r " +            "doesn't match %r" +            % (hostname, dnsnames[0])) +    else: +        raise CertificateError("no appropriate commonName or " +            "subjectAltName fields were found") + + +class SSLContext(_SSLContext): +    """An SSLContext holds various SSL-related configuration options and +    data, such as certificates and possibly a private key.""" + +    __slots__ = ('protocol',) + +    def __new__(cls, protocol, *args, **kwargs): +        return _SSLContext.__new__(cls, protocol) + +    def __init__(self, protocol): +        self.protocol = protocol + +    def wrap_socket(self, sock, server_side=False, +                    do_handshake_on_connect=True, +                    suppress_ragged_eofs=True, +                    server_hostname=None): +        return SSLSocket(sock=sock, server_side=server_side, +                         do_handshake_on_connect=do_handshake_on_connect, +                         suppress_ragged_eofs=suppress_ragged_eofs, +                         server_hostname=server_hostname, +                         _context=self) + + +class SSLSocket(socket):      """This class implements a subtype of socket.socket that wraps      the underlying OS socket in an SSL context when necessary, and      provides read and write methods over that channel.""" @@ -94,15 +178,48 @@ class SSLSocket(socket):                   ssl_version=PROTOCOL_SSLv23, ca_certs=None,                   do_handshake_on_connect=True,                   family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, -                 suppress_ragged_eofs=True): +                 suppress_ragged_eofs=True, ciphers=None, +                 server_hostname=None, +                 _context=None): +        if _context: +            self.context = _context +        else: +            if server_side and not certfile: +                raise ValueError("certfile must be specified for server-side " +                                 "operations") +            if keyfile and not certfile: +                raise ValueError("certfile must be specified") +            if certfile and not keyfile: +                keyfile = certfile +            self.context = SSLContext(ssl_version) +            self.context.verify_mode = cert_reqs +            if ca_certs: +                self.context.load_verify_locations(ca_certs) +            if certfile: +                self.context.load_cert_chain(certfile, keyfile) +            if ciphers: +                self.context.set_ciphers(ciphers) +            self.keyfile = keyfile +            self.certfile = certfile +            self.cert_reqs = cert_reqs +            self.ssl_version = ssl_version +            self.ca_certs = ca_certs +            self.ciphers = ciphers +        if server_side and server_hostname: +            raise ValueError("server_hostname can only be specified " +                             "in client mode") +        self.server_side = server_side +        self.server_hostname = server_hostname +        self.do_handshake_on_connect = do_handshake_on_connect +        self.suppress_ragged_eofs = suppress_ragged_eofs          connected = False          if sock is not None:              socket.__init__(self,                              family=sock.family,                              type=sock.type,                              proto=sock.proto, -                            fileno=_dup(sock.fileno())) +                            fileno=sock.fileno())              self.settimeout(sock.gettimeout())              # see if it's connected              try: @@ -112,23 +229,20 @@ class SSLSocket(socket):                      raise              else:                  connected = True -            sock.close() +            sock.detach()          elif fileno is not None:              socket.__init__(self, fileno=fileno)          else:              socket.__init__(self, family=family, type=type, proto=proto) -        if certfile and not keyfile: -            keyfile = certfile -          self._closed = False          self._sslobj = None +        self._connected = connected          if connected:              # create the SSL object              try: -                self._sslobj = _ssl.sslwrap(self, server_side, -                                            keyfile, certfile, -                                            cert_reqs, ssl_version, ca_certs) +                self._sslobj = self.context._wrap_socket(self, server_side, +                                                         server_hostname)                  if do_handshake_on_connect:                      timeout = self.gettimeout()                      if timeout == 0.0: @@ -140,14 +254,6 @@ class SSLSocket(socket):                  self.close()                  raise x -        self.keyfile = keyfile -        self.certfile = certfile -        self.cert_reqs = cert_reqs -        self.ssl_version = ssl_version -        self.ca_certs = ca_certs -        self.do_handshake_on_connect = do_handshake_on_connect -        self.suppress_ragged_eofs = suppress_ragged_eofs -      def dup(self):          raise NotImplemented("Can't dup() %s instances" %                               self.__class__.__name__) @@ -234,6 +340,10 @@ class SSLSocket(socket):      def sendall(self, data, flags=0):          self._checkClosed()          if self._sslobj: +            if flags != 0: +                raise ValueError( +                    "non-zero flags not allowed in calls to sendall() on %s" % +                    self.__class__)              amount = len(data)              count = 0              while (count < amount): @@ -321,24 +431,36 @@ class SSLSocket(socket):          finally:              self.settimeout(timeout) -    def connect(self, addr): -        """Connects to remote ADDR, and then wraps the connection in -        an SSL channel.""" - +    def _real_connect(self, addr, return_errno): +        if self.server_side: +            raise ValueError("can't connect in server-side mode")          # Here we assume that the socket is client-side, and not          # connected at the time of the call.  We connect it, then wrap it. -        if self._sslobj: +        if self._connected:              raise ValueError("attempt to connect already-connected SSLSocket!") -        socket.connect(self, addr) -        self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile, -                                    self.cert_reqs, self.ssl_version, -                                    self.ca_certs) +        self._sslobj = self.context._wrap_socket(self, False, self.server_hostname)          try: +            socket.connect(self, addr)              if self.do_handshake_on_connect:                  self.do_handshake() -        except: -            self._sslobj = None -            raise +        except socket_error as e: +            if return_errno: +                return e.errno +            else: +                self._sslobj = None +                raise e +        self._connected = True +        return 0 + +    def connect(self, addr): +        """Connects to remote ADDR, and then wraps the connection in +        an SSL channel.""" +        self._real_connect(addr, False) + +    def connect_ex(self, addr): +        """Connects to remote ADDR, and then wraps the connection in +        an SSL channel.""" +        return self._real_connect(addr, True)      def accept(self):          """Accepts a new connection from a remote client, and returns @@ -352,6 +474,7 @@ class SSLSocket(socket):                            cert_reqs=self.cert_reqs,                            ssl_version=self.ssl_version,                            ca_certs=self.ca_certs, +                          ciphers=self.ciphers,                            do_handshake_on_connect=                                self.do_handshake_on_connect),                  addr) @@ -365,13 +488,14 @@ def wrap_socket(sock, keyfile=None, certfile=None,                  server_side=False, cert_reqs=CERT_NONE,                  ssl_version=PROTOCOL_SSLv23, ca_certs=None,                  do_handshake_on_connect=True, -                suppress_ragged_eofs=True): +                suppress_ragged_eofs=True, ciphers=None):      return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,                       server_side=server_side, cert_reqs=cert_reqs,                       ssl_version=ssl_version, ca_certs=ca_certs,                       do_handshake_on_connect=do_handshake_on_connect, -                     suppress_ragged_eofs=suppress_ragged_eofs) +                     suppress_ragged_eofs=suppress_ragged_eofs, +                     ciphers=ciphers)  # some utility functions | 
