summaryrefslogtreecommitdiffstats
path: root/Lib/ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r--Lib/ssl.py255
1 files changed, 195 insertions, 60 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index d9d1916..8854e37 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -92,12 +92,12 @@ import re
import sys
import os
from collections import namedtuple
-from enum import Enum as _Enum
+from enum import Enum as _Enum, IntEnum as _IntEnum
import _ssl # if we can't import it, let the error propagate
from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
-from _ssl import _SSLContext
+from _ssl import _SSLContext, MemoryBIO
from _ssl import (
SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
SSLSyscallError, SSLEOFError,
@@ -119,30 +119,19 @@ _import_symbols('SSL_ERROR_')
from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN
-from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
from _ssl import _OPENSSL_API_VERSION
+_SSLMethod = _IntEnum('_SSLMethod',
+ {name: value for name, value in vars(_ssl).items()
+ if name.startswith('PROTOCOL_')})
+globals().update(_SSLMethod.__members__)
+
+_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()}
-_PROTOCOL_NAMES = {
- PROTOCOL_TLSv1: "TLSv1",
- PROTOCOL_SSLv23: "SSLv23",
- PROTOCOL_SSLv3: "SSLv3",
-}
try:
- from _ssl import PROTOCOL_SSLv2
_SSLv2_IF_EXISTS = PROTOCOL_SSLv2
-except ImportError:
+except NameError:
_SSLv2_IF_EXISTS = None
-else:
- _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2"
-
-try:
- from _ssl import PROTOCOL_TLSv1_1, PROTOCOL_TLSv1_2
-except ImportError:
- pass
-else:
- _PROTOCOL_NAMES[PROTOCOL_TLSv1_1] = "TLSv1.1"
- _PROTOCOL_NAMES[PROTOCOL_TLSv1_2] = "TLSv1.2"
if sys.platform == "win32":
from _ssl import enum_certificates, enum_crls
@@ -363,6 +352,12 @@ class SSLContext(_SSLContext):
server_hostname=server_hostname,
_context=self)
+ def wrap_bio(self, incoming, outgoing, server_side=False,
+ server_hostname=None):
+ sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side,
+ server_hostname=server_hostname)
+ return SSLObject(sslobj)
+
def set_npn_protocols(self, npn_protocols):
protos = bytearray()
for protocol in npn_protocols:
@@ -480,6 +475,129 @@ def _create_stdlib_context(protocol=PROTOCOL_SSLv23, *, cert_reqs=None,
return context
+
+class SSLObject:
+ """This class implements an interface on top of a low-level SSL object as
+ implemented by OpenSSL. This object captures the state of an SSL connection
+ but does not provide any network IO itself. IO needs to be performed
+ through separate "BIO" objects which are OpenSSL's IO abstraction layer.
+
+ This class does not have a public constructor. Instances are returned by
+ ``SSLContext.wrap_bio``. This class is typically used by framework authors
+ that want to implement asynchronous IO for SSL through memory buffers.
+
+ When compared to ``SSLSocket``, this object lacks the following features:
+
+ * Any form of network IO incluging methods such as ``recv`` and ``send``.
+ * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
+ """
+
+ def __init__(self, sslobj, owner=None):
+ self._sslobj = sslobj
+ # Note: _sslobj takes a weak reference to owner
+ self._sslobj.owner = owner or self
+
+ @property
+ def context(self):
+ """The SSLContext that is currently in use."""
+ return self._sslobj.context
+
+ @context.setter
+ def context(self, ctx):
+ self._sslobj.context = ctx
+
+ @property
+ def server_side(self):
+ """Whether this is a server-side socket."""
+ return self._sslobj.server_side
+
+ @property
+ def server_hostname(self):
+ """The currently set server hostname (for SNI), or ``None`` if no
+ server hostame is set."""
+ return self._sslobj.server_hostname
+
+ def read(self, len=0, buffer=None):
+ """Read up to 'len' bytes from the SSL object and return them.
+
+ If 'buffer' is provided, read into this buffer and return the number of
+ bytes read.
+ """
+ if buffer is not None:
+ v = self._sslobj.read(len, buffer)
+ else:
+ v = self._sslobj.read(len or 1024)
+ return v
+
+ def write(self, data):
+ """Write 'data' to the SSL object and return the number of bytes
+ written.
+
+ The 'data' argument must support the buffer interface.
+ """
+ return self._sslobj.write(data)
+
+ def getpeercert(self, binary_form=False):
+ """Returns a formatted version of the data in the certificate provided
+ by the other end of the SSL channel.
+
+ Return None if no certificate was provided, {} if a certificate was
+ provided, but not validated.
+ """
+ return self._sslobj.peer_certificate(binary_form)
+
+ def selected_npn_protocol(self):
+ """Return the currently selected NPN protocol as a string, or ``None``
+ if a next protocol was not negotiated or if NPN is not supported by one
+ of the peers."""
+ if _ssl.HAS_NPN:
+ return self._sslobj.selected_npn_protocol()
+
+ def cipher(self):
+ """Return the currently selected cipher as a 3-tuple ``(name,
+ ssl_version, secret_bits)``."""
+ return self._sslobj.cipher()
+
+ def compression(self):
+ """Return the current compression algorithm in use, or ``None`` if
+ compression was not negotiated or not supported by one of the peers."""
+ return self._sslobj.compression()
+
+ def pending(self):
+ """Return the number of bytes that can be read immediately."""
+ return self._sslobj.pending()
+
+ def do_handshake(self):
+ """Start the SSL/TLS handshake."""
+ self._sslobj.do_handshake()
+ if self.context.check_hostname:
+ if not self.server_hostname:
+ raise ValueError("check_hostname needs server_hostname "
+ "argument")
+ match_hostname(self.getpeercert(), self.server_hostname)
+
+ def unwrap(self):
+ """Start the SSL shutdown handshake."""
+ return self._sslobj.shutdown()
+
+ def get_channel_binding(self, cb_type="tls-unique"):
+ """Get channel binding data for current connection. Raise ValueError
+ if the requested `cb_type` is not supported. Return bytes of the data
+ or None if the data is not available (e.g. before the handshake)."""
+ if cb_type not in CHANNEL_BINDING_TYPES:
+ raise ValueError("Unsupported channel binding type")
+ if cb_type != "tls-unique":
+ raise NotImplementedError(
+ "{0} channel binding type not implemented"
+ .format(cb_type))
+ return self._sslobj.tls_unique_cb()
+
+ def version(self):
+ """Return a string identifying the protocol version used by the
+ current SSL channel. """
+ return self._sslobj.version()
+
+
class SSLSocket(socket):
"""This class implements a subtype of socket.socket that wraps
the underlying OS socket in an SSL context when necessary, and
@@ -567,8 +685,9 @@ class SSLSocket(socket):
if connected:
# create the SSL object
try:
- self._sslobj = self._context._wrap_socket(self, server_side,
- server_hostname)
+ sslobj = self._context._wrap_socket(self, server_side,
+ server_hostname)
+ self._sslobj = SSLObject(sslobj, owner=self)
if do_handshake_on_connect:
timeout = self.gettimeout()
if timeout == 0.0:
@@ -613,11 +732,7 @@ class SSLSocket(socket):
if not self._sslobj:
raise ValueError("Read on closed or unwrapped SSL socket.")
try:
- if buffer is not None:
- v = self._sslobj.read(len, buffer)
- else:
- v = self._sslobj.read(len or 1024)
- return v
+ return self._sslobj.read(len, buffer)
except SSLError as x:
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
if buffer is not None:
@@ -644,7 +759,7 @@ class SSLSocket(socket):
self._checkClosed()
self._check_connected()
- return self._sslobj.peer_certificate(binary_form)
+ return self._sslobj.getpeercert(binary_form)
def selected_npn_protocol(self):
self._checkClosed()
@@ -674,17 +789,7 @@ class SSLSocket(socket):
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
self.__class__)
- try:
- v = self._sslobj.write(data)
- except SSLError as x:
- if x.args[0] == SSL_ERROR_WANT_READ:
- return 0
- elif x.args[0] == SSL_ERROR_WANT_WRITE:
- return 0
- else:
- raise
- else:
- return v
+ return self._sslobj.write(data)
else:
return socket.send(self, data, flags)
@@ -720,6 +825,16 @@ class SSLSocket(socket):
else:
return socket.sendall(self, data, flags)
+ def sendfile(self, file, offset=0, count=None):
+ """Send a file, possibly by using os.sendfile() if this is a
+ clear-text socket. Return the total number of bytes sent.
+ """
+ if self._sslobj is None:
+ # os.sendfile() works with plain sockets only
+ return super().sendfile(file, offset, count)
+ else:
+ return self._sendfile_use_send(file, offset, count)
+
def recv(self, buflen=1024, flags=0):
self._checkClosed()
if self._sslobj:
@@ -784,7 +899,7 @@ class SSLSocket(socket):
def unwrap(self):
if self._sslobj:
- s = self._sslobj.shutdown()
+ s = self._sslobj.unwrap()
self._sslobj = None
return s
else:
@@ -805,12 +920,6 @@ class SSLSocket(socket):
finally:
self.settimeout(timeout)
- if self.context.check_hostname:
- if not self.server_hostname:
- raise ValueError("check_hostname needs server_hostname "
- "argument")
- match_hostname(self.getpeercert(), self.server_hostname)
-
def _real_connect(self, addr, connect_ex):
if self.server_side:
raise ValueError("can't connect in server-side mode")
@@ -818,7 +927,8 @@ class SSLSocket(socket):
# connected at the time of the call. We connect it, then wrap it.
if self._connected:
raise ValueError("attempt to connect already-connected SSLSocket!")
- self._sslobj = self.context._wrap_socket(self, False, self.server_hostname)
+ sslobj = self.context._wrap_socket(self, False, self.server_hostname)
+ self._sslobj = SSLObject(sslobj, owner=self)
try:
if connect_ex:
rc = socket.connect_ex(self, addr)
@@ -861,15 +971,18 @@ class SSLSocket(socket):
if the requested `cb_type` is not supported. Return bytes of the data
or None if the data is not available (e.g. before the handshake).
"""
- if cb_type not in CHANNEL_BINDING_TYPES:
- raise ValueError("Unsupported channel binding type")
- if cb_type != "tls-unique":
- raise NotImplementedError(
- "{0} channel binding type not implemented"
- .format(cb_type))
if self._sslobj is None:
return None
- return self._sslobj.tls_unique_cb()
+ return self._sslobj.get_channel_binding(cb_type)
+
+ def version(self):
+ """
+ Return a string identifying the protocol version used by the
+ current SSL channel, or None if there is no established channel.
+ """
+ if self._sslobj is None:
+ return None
+ return self._sslobj.version()
def wrap_socket(sock, keyfile=None, certfile=None,
@@ -889,12 +1002,34 @@ def wrap_socket(sock, keyfile=None, certfile=None,
# some utility functions
def cert_time_to_seconds(cert_time):
- """Takes a date-time string in standard ASN1_print form
- ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
- a Python time value in seconds past the epoch."""
+ """Return the time in seconds since the Epoch, given the timestring
+ representing the "notBefore" or "notAfter" date from a certificate
+ in ``"%b %d %H:%M:%S %Y %Z"`` strptime format (C locale).
+
+ "notBefore" or "notAfter" dates must use UTC (RFC 5280).
+
+ Month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec
+ UTC should be specified as GMT (see ASN1_TIME_print())
+ """
+ from time import strptime
+ from calendar import timegm
- import time
- return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT"))
+ months = (
+ "Jan","Feb","Mar","Apr","May","Jun",
+ "Jul","Aug","Sep","Oct","Nov","Dec"
+ )
+ time_format = ' %d %H:%M:%S %Y GMT' # NOTE: no month, fixed GMT
+ try:
+ month_number = months.index(cert_time[:3].title()) + 1
+ except ValueError:
+ raise ValueError('time data %r does not match '
+ 'format "%%b%s"' % (cert_time, time_format))
+ else:
+ # found valid month
+ tt = strptime(cert_time[3:], time_format)
+ # return an integer, the previous mktime()-based implementation
+ # returned a float (fractional seconds are always zero here).
+ return timegm((tt[0], month_number) + tt[2:6])
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"
@@ -921,7 +1056,7 @@ def PEM_cert_to_DER_cert(pem_cert_string):
d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
return base64.decodebytes(d.encode('ASCII', 'strict'))
-def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
+def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):
"""Retrieve the certificate from the server at the specified address,
and return it as a PEM-encoded string.
If 'ca_certs' is specified, validate the server cert against it.