diff options
Diffstat (limited to 'Lib/ssl.py')
| -rw-r--r-- | Lib/ssl.py | 92 | 
1 files changed, 56 insertions, 36 deletions
| @@ -92,7 +92,7 @@ import re  import sys  import os  from collections import namedtuple -from enum import Enum as _Enum +from enum import Enum as _Enum, IntEnum as _IntEnum  import _ssl             # if we can't import it, let the error propagate @@ -119,30 +119,19 @@ _import_symbols('SSL_ERROR_')  from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN -from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1  from _ssl import _OPENSSL_API_VERSION +_SSLMethod = _IntEnum('_SSLMethod', +                      {name: value for name, value in vars(_ssl).items() +                       if name.startswith('PROTOCOL_')}) +globals().update(_SSLMethod.__members__) + +_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()} -_PROTOCOL_NAMES = { -    PROTOCOL_TLSv1: "TLSv1", -    PROTOCOL_SSLv23: "SSLv23", -    PROTOCOL_SSLv3: "SSLv3", -}  try: -    from _ssl import PROTOCOL_SSLv2      _SSLv2_IF_EXISTS = PROTOCOL_SSLv2 -except ImportError: +except NameError:      _SSLv2_IF_EXISTS = None -else: -    _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2" - -try: -    from _ssl import PROTOCOL_TLSv1_1, PROTOCOL_TLSv1_2 -except ImportError: -    pass -else: -    _PROTOCOL_NAMES[PROTOCOL_TLSv1_1] = "TLSv1.1" -    _PROTOCOL_NAMES[PROTOCOL_TLSv1_2] = "TLSv1.2"  if sys.platform == "win32":      from _ssl import enum_certificates, enum_crls @@ -675,17 +664,7 @@ class SSLSocket(socket):                  raise ValueError(                      "non-zero flags not allowed in calls to send() on %s" %                      self.__class__) -            try: -                v = self._sslobj.write(data) -            except SSLError as x: -                if x.args[0] == SSL_ERROR_WANT_READ: -                    return 0 -                elif x.args[0] == SSL_ERROR_WANT_WRITE: -                    return 0 -                else: -                    raise -            else: -                return v +            return self._sslobj.write(data)          else:              return socket.send(self, data, flags) @@ -721,6 +700,16 @@ class SSLSocket(socket):          else:              return socket.sendall(self, data, flags) +    def sendfile(self, file, offset=0, count=None): +        """Send a file, possibly by using os.sendfile() if this is a +        clear-text socket.  Return the total number of bytes sent. +        """ +        if self._sslobj is None: +            # os.sendfile() works with plain sockets only +            return super().sendfile(file, offset, count) +        else: +            return self._sendfile_use_send(file, offset, count) +      def recv(self, buflen=1024, flags=0):          self._checkClosed()          if self._sslobj: @@ -872,6 +861,15 @@ class SSLSocket(socket):              return None          return self._sslobj.tls_unique_cb() +    def version(self): +        """ +        Return a string identifying the protocol version used by the +        current SSL channel, or None if there is no established channel. +        """ +        if self._sslobj is None: +            return None +        return self._sslobj.version() +  def wrap_socket(sock, keyfile=None, certfile=None,                  server_side=False, cert_reqs=CERT_NONE, @@ -890,12 +888,34 @@ def wrap_socket(sock, keyfile=None, certfile=None,  # some utility functions  def cert_time_to_seconds(cert_time): -    """Takes a date-time string in standard ASN1_print form -    ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return -    a Python time value in seconds past the epoch.""" +    """Return the time in seconds since the Epoch, given the timestring +    representing the "notBefore" or "notAfter" date from a certificate +    in ``"%b %d %H:%M:%S %Y %Z"`` strptime format (C locale). + +    "notBefore" or "notAfter" dates must use UTC (RFC 5280). + +    Month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec +    UTC should be specified as GMT (see ASN1_TIME_print()) +    """ +    from time import strptime +    from calendar import timegm -    import time -    return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")) +    months = ( +        "Jan","Feb","Mar","Apr","May","Jun", +        "Jul","Aug","Sep","Oct","Nov","Dec" +    ) +    time_format = ' %d %H:%M:%S %Y GMT' # NOTE: no month, fixed GMT +    try: +        month_number = months.index(cert_time[:3].title()) + 1 +    except ValueError: +        raise ValueError('time data %r does not match ' +                         'format "%%b%s"' % (cert_time, time_format)) +    else: +        # found valid month +        tt = strptime(cert_time[3:], time_format) +        # return an integer, the previous mktime()-based implementation +        # returned a float (fractional seconds are always zero here). +        return timegm((tt[0], month_number) + tt[2:6])  PEM_HEADER = "-----BEGIN CERTIFICATE-----"  PEM_FOOTER = "-----END CERTIFICATE-----" @@ -922,7 +942,7 @@ def PEM_cert_to_DER_cert(pem_cert_string):      d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]      return base64.decodebytes(d.encode('ASCII', 'strict')) -def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): +def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):      """Retrieve the certificate from the server at the specified address,      and return it as a PEM-encoded string.      If 'ca_certs' is specified, validate the server cert against it. | 
