diff options
Diffstat (limited to 'Lib/ssl.py')
| -rw-r--r-- | Lib/ssl.py | 100 | 
1 files changed, 87 insertions, 13 deletions
@@ -60,10 +60,25 @@ import re  import _ssl             # if we can't import it, let the error propagate  from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import _SSLContext, SSLError +from _ssl import _SSLContext +from _ssl import ( +    SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, +    SSLSyscallError, SSLEOFError, +    )  from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED -from _ssl import OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1 -from _ssl import RAND_status, RAND_egd, RAND_add +from _ssl import ( +    OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1, +    OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE +    ) +try: +    from _ssl import OP_NO_COMPRESSION +except ImportError: +    pass +try: +    from _ssl import OP_SINGLE_ECDH_USE +except ImportError: +    pass +from _ssl import RAND_status, RAND_egd, RAND_add, RAND_bytes, RAND_pseudo_bytes  from _ssl import (      SSL_ERROR_ZERO_RETURN,      SSL_ERROR_WANT_READ, @@ -75,8 +90,9 @@ from _ssl import (      SSL_ERROR_EOF,      SSL_ERROR_INVALID_ERROR_CODE,      ) -from _ssl import HAS_SNI -from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 +from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN +from _ssl import (PROTOCOL_SSLv3, PROTOCOL_SSLv23, +                  PROTOCOL_TLSv1)  from _ssl import _OPENSSL_API_VERSION  _PROTOCOL_NAMES = { @@ -94,11 +110,16 @@ else:  from socket import getnameinfo as _getnameinfo  from socket import error as socket_error -from socket import socket, AF_INET, SOCK_STREAM +from socket import socket, AF_INET, SOCK_STREAM, create_connection  import base64        # for DER-to-PEM translation  import traceback  import errno +if _ssl.HAS_TLS_UNIQUE: +    CHANNEL_BINDING_TYPES = ['tls-unique'] +else: +    CHANNEL_BINDING_TYPES = [] +  # Disable weak or insecure ciphers by default  # (OpenSSL's default setting is 'DEFAULT:!aNULL:!eNULL')  _DEFAULT_CIPHERS = 'DEFAULT:!aNULL:!eNULL:!LOW:!EXPORT:!SSLv2' @@ -188,6 +209,17 @@ class SSLContext(_SSLContext):                           server_hostname=server_hostname,                           _context=self) +    def set_npn_protocols(self, npn_protocols): +        protos = bytearray() +        for protocol in npn_protocols: +            b = bytes(protocol, 'ascii') +            if len(b) == 0 or len(b) > 255: +                raise SSLError('NPN protocols must be 1 to 255 in length') +            protos.append(len(b)) +            protos.extend(b) + +        self._set_npn_protocols(protos) +  class SSLSocket(socket):      """This class implements a subtype of socket.socket that wraps @@ -199,7 +231,7 @@ class SSLSocket(socket):                   ssl_version=PROTOCOL_SSLv23, ca_certs=None,                   do_handshake_on_connect=True,                   family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, -                 suppress_ragged_eofs=True, ciphers=None, +                 suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,                   server_hostname=None,                   _context=None): @@ -219,6 +251,8 @@ class SSLSocket(socket):                  self.context.load_verify_locations(ca_certs)              if certfile:                  self.context.load_cert_chain(certfile, keyfile) +            if npn_protocols: +                self.context.set_npn_protocols(npn_protocols)              if ciphers:                  self.context.set_ciphers(ciphers)              self.keyfile = keyfile @@ -319,6 +353,13 @@ class SSLSocket(socket):          self._checkClosed()          return self._sslobj.peer_certificate(binary_form) +    def selected_npn_protocol(self): +        self._checkClosed() +        if not self._sslobj or not _ssl.HAS_NPN: +            return None +        else: +            return self._sslobj.selected_npn_protocol() +      def cipher(self):          self._checkClosed()          if not self._sslobj: @@ -326,6 +367,13 @@ class SSLSocket(socket):          else:              return self._sslobj.cipher() +    def compression(self): +        self._checkClosed() +        if not self._sslobj: +            return None +        else: +            return self._sslobj.compression() +      def send(self, data, flags=0):          self._checkClosed()          if self._sslobj: @@ -358,6 +406,12 @@ class SSLSocket(socket):          else:              return socket.sendto(self, data, flags_or_addr, addr) +    def sendmsg(self, *args, **kwargs): +        # Ensure programs don't send data unencrypted if they try to +        # use this method. +        raise NotImplementedError("sendmsg not allowed on instances of %s" % +                                  self.__class__) +      def sendall(self, data, flags=0):          self._checkClosed()          if self._sslobj: @@ -416,6 +470,14 @@ class SSLSocket(socket):          else:              return socket.recvfrom_into(self, buffer, nbytes, flags) +    def recvmsg(self, *args, **kwargs): +        raise NotImplementedError("recvmsg not allowed on instances of %s" % +                                  self.__class__) + +    def recvmsg_into(self, *args, **kwargs): +        raise NotImplementedError("recvmsg_into not allowed on instances of " +                                  "%s" % self.__class__) +      def pending(self):          self._checkClosed()          if self._sslobj: @@ -497,16 +559,28 @@ class SSLSocket(socket):                      server_side=True)          return newsock, addr -    def __del__(self): -        # sys.stderr.write("__del__ on %s\n" % repr(self)) -        self._real_close() +    def get_channel_binding(self, cb_type="tls-unique"): +        """Get channel binding data for current connection.  Raise ValueError +        if the requested `cb_type` is not supported.  Return bytes of the data +        or None if the data is not available (e.g. before the handshake). +        """ +        if cb_type not in CHANNEL_BINDING_TYPES: +            raise ValueError("Unsupported channel binding type") +        if cb_type != "tls-unique": +            raise NotImplementedError( +                            "{0} channel binding type not implemented" +                            .format(cb_type)) +        if self._sslobj is None: +            return None +        return self._sslobj.tls_unique_cb()  def wrap_socket(sock, keyfile=None, certfile=None,                  server_side=False, cert_reqs=CERT_NONE,                  ssl_version=PROTOCOL_SSLv23, ca_certs=None,                  do_handshake_on_connect=True, -                suppress_ragged_eofs=True, ciphers=None): +                suppress_ragged_eofs=True, +                ciphers=None):      return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,                       server_side=server_side, cert_reqs=cert_reqs, @@ -561,9 +635,9 @@ def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):          cert_reqs = CERT_REQUIRED      else:          cert_reqs = CERT_NONE -    s = wrap_socket(socket(), ssl_version=ssl_version, +    s = create_connection(addr) +    s = wrap_socket(s, ssl_version=ssl_version,                      cert_reqs=cert_reqs, ca_certs=ca_certs) -    s.connect(addr)      dercert = s.getpeercert(True)      s.close()      return DER_cert_to_PEM_cert(dercert)  | 
