summaryrefslogtreecommitdiffstats
path: root/Lib/ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r--Lib/ssl.py361
1 files changed, 338 insertions, 23 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index 388c931..9a120f2 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -8,11 +8,11 @@ This module provides some more Pythonic support for SSL.
Object types:
- sslsocket -- subtype of socket.socket which does SSL over the socket
+ SSLSocket -- subtype of socket.socket which does SSL over the socket
Exceptions:
- sslerror -- exception raised for I/O errors
+ SSLError -- exception raised for I/O errors
Functions:
@@ -57,12 +57,14 @@ PROTOCOL_SSLv23
PROTOCOL_TLSv1
"""
-import os, sys
+import os, sys, textwrap
import _ssl # if we can't import it, let the error propagate
-from _ssl import sslerror
+
+from _ssl import SSLError
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
+from _ssl import RAND_status, RAND_egd, RAND_add
from _ssl import \
SSL_ERROR_ZERO_RETURN, \
SSL_ERROR_WANT_READ, \
@@ -76,9 +78,9 @@ from _ssl import \
from socket import socket
from socket import getnameinfo as _getnameinfo
+import base64 # for DER-to-PEM translation
-
-class sslsocket (socket):
+class SSLSocket (socket):
"""This class implements a subtype of socket.socket that wraps
the underlying OS socket in an SSL context when necessary, and
@@ -121,14 +123,21 @@ class sslsocket (socket):
return self._sslobj.write(data)
- def getpeercert(self):
+ def getpeercert(self, binary_form=False):
"""Returns a formatted version of the data in the
certificate provided by the other end of the SSL channel.
Return None if no certificate was provided, {} if a
certificate was provided, but not validated."""
- return self._sslobj.peer_certificate()
+ return self._sslobj.peer_certificate(binary_form)
+
+ def cipher (self):
+
+ if not self._sslobj:
+ return None
+ else:
+ return self._sslobj.cipher()
def send (self, data, flags=0):
if self._sslobj:
@@ -174,21 +183,12 @@ class sslsocket (socket):
else:
return socket.recv_from(self, addr, buflen, flags)
- def ssl_shutdown(self):
-
- """Shuts down the SSL channel over this socket (if active),
- without closing the socket connection."""
-
- if self._sslobj:
- self._sslobj.shutdown()
- self._sslobj = None
-
def shutdown(self, how):
- self.ssl_shutdown()
+ self._sslobj = None
socket.shutdown(self, how)
def close(self):
- self.ssl_shutdown()
+ self._sslobj = None
socket.close(self)
def connect(self, addr):
@@ -199,7 +199,7 @@ class sslsocket (socket):
# Here we assume that the socket is client-side, and not
# connected at the time of the call. We connect it, then wrap it.
if self._sslobj:
- raise ValueError("attempt to connect already-connected sslsocket!")
+ raise ValueError("attempt to connect already-connected SSLSocket!")
socket.connect(self, addr)
self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
@@ -212,11 +212,261 @@ class sslsocket (socket):
SSL channel, and the address of the remote client."""
newsock, addr = socket.accept(self)
- return (sslsocket(newsock, True, self.keyfile, self.certfile,
- self.cert_reqs, self.ssl_version,
- self.ca_certs), addr)
+ return (SSLSocket(newsock, True, self.keyfile, self.certfile,
+ self.cert_reqs, self.ssl_version,
+ self.ca_certs), addr)
+
+
+ def makefile(self, mode='r', bufsize=-1):
+
+ """Ouch. Need to make and return a file-like object that
+ works with the SSL connection."""
+
+ if self._sslobj:
+ return SSLFileStream(self._sslobj, mode, bufsize)
+ else:
+ return socket.makefile(self, mode, bufsize)
+
+
+class SSLFileStream:
+
+ """A class to simulate a file stream on top of a socket.
+ Most of this is just lifted from the socket module, and
+ adjusted to work with an SSL stream instead of a socket."""
+ default_bufsize = 8192
+ name = "<SSL stream>"
+
+ __slots__ = ["mode", "bufsize", "softspace",
+ # "closed" is a property, see below
+ "_sslobj", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf",
+ "_close", "_fileno"]
+
+ def __init__(self, sslobj, mode='rb', bufsize=-1, close=False):
+ self._sslobj = sslobj
+ self.mode = mode # Not actually used in this version
+ if bufsize < 0:
+ bufsize = self.default_bufsize
+ self.bufsize = bufsize
+ self.softspace = False
+ if bufsize == 0:
+ self._rbufsize = 1
+ elif bufsize == 1:
+ self._rbufsize = self.default_bufsize
+ else:
+ self._rbufsize = bufsize
+ self._wbufsize = bufsize
+ self._rbuf = "" # A string
+ self._wbuf = [] # A list of strings
+ self._close = close
+ self._fileno = -1
+
+ def _getclosed(self):
+ return self._sslobj is None
+ closed = property(_getclosed, doc="True if the file is closed")
+
+ def fileno(self):
+ return self._fileno
+
+ def close(self):
+ try:
+ if self._sslobj:
+ self.flush()
+ finally:
+ if self._close and self._sslobj:
+ self._sslobj.close()
+ self._sslobj = None
+
+ def __del__(self):
+ try:
+ self.close()
+ except:
+ # close() may fail if __init__ didn't complete
+ pass
+
+ def flush(self):
+ if self._wbuf:
+ buffer = "".join(self._wbuf)
+ self._wbuf = []
+ count = 0
+ while (count < len(buffer)):
+ written = self._sslobj.write(buffer)
+ count += written
+ buffer = buffer[written:]
+
+ def write(self, data):
+ data = str(data) # XXX Should really reject non-string non-buffers
+ if not data:
+ return
+ self._wbuf.append(data)
+ if (self._wbufsize == 0 or
+ self._wbufsize == 1 and '\n' in data or
+ self._get_wbuf_len() >= self._wbufsize):
+ self.flush()
+
+ def writelines(self, list):
+ # XXX We could do better here for very long lists
+ # XXX Should really reject non-string non-buffers
+ self._wbuf.extend(filter(None, map(str, list)))
+ if (self._wbufsize <= 1 or
+ self._get_wbuf_len() >= self._wbufsize):
+ self.flush()
+
+ def _get_wbuf_len(self):
+ buf_len = 0
+ for x in self._wbuf:
+ buf_len += len(x)
+ return buf_len
+
+ def read(self, size=-1):
+ data = self._rbuf
+ if size < 0:
+ # Read until EOF
+ buffers = []
+ if data:
+ buffers.append(data)
+ self._rbuf = ""
+ if self._rbufsize <= 1:
+ recv_size = self.default_bufsize
+ else:
+ recv_size = self._rbufsize
+ while True:
+ data = self._sslobj.read(recv_size)
+ if not data:
+ break
+ buffers.append(data)
+ return "".join(buffers)
+ else:
+ # Read until size bytes or EOF seen, whichever comes first
+ buf_len = len(data)
+ if buf_len >= size:
+ self._rbuf = data[size:]
+ return data[:size]
+ buffers = []
+ if data:
+ buffers.append(data)
+ self._rbuf = ""
+ while True:
+ left = size - buf_len
+ recv_size = max(self._rbufsize, left)
+ data = self._sslobj.read(recv_size)
+ if not data:
+ break
+ buffers.append(data)
+ n = len(data)
+ if n >= left:
+ self._rbuf = data[left:]
+ buffers[-1] = data[:left]
+ break
+ buf_len += n
+ return "".join(buffers)
+
+ def readline(self, size=-1):
+ data = self._rbuf
+ if size < 0:
+ # Read until \n or EOF, whichever comes first
+ if self._rbufsize <= 1:
+ # Speed up unbuffered case
+ assert data == ""
+ buffers = []
+ while data != "\n":
+ data = self._sslobj.read(1)
+ if not data:
+ break
+ buffers.append(data)
+ return "".join(buffers)
+ nl = data.find('\n')
+ if nl >= 0:
+ nl += 1
+ self._rbuf = data[nl:]
+ return data[:nl]
+ buffers = []
+ if data:
+ buffers.append(data)
+ self._rbuf = ""
+ while True:
+ data = self._sslobj.read(self._rbufsize)
+ if not data:
+ break
+ buffers.append(data)
+ nl = data.find('\n')
+ if nl >= 0:
+ nl += 1
+ self._rbuf = data[nl:]
+ buffers[-1] = data[:nl]
+ break
+ return "".join(buffers)
+ else:
+ # Read until size bytes or \n or EOF seen, whichever comes first
+ nl = data.find('\n', 0, size)
+ if nl >= 0:
+ nl += 1
+ self._rbuf = data[nl:]
+ return data[:nl]
+ buf_len = len(data)
+ if buf_len >= size:
+ self._rbuf = data[size:]
+ return data[:size]
+ buffers = []
+ if data:
+ buffers.append(data)
+ self._rbuf = ""
+ while True:
+ data = self._sslobj.read(self._rbufsize)
+ if not data:
+ break
+ buffers.append(data)
+ left = size - buf_len
+ nl = data.find('\n', 0, left)
+ if nl >= 0:
+ nl += 1
+ self._rbuf = data[nl:]
+ buffers[-1] = data[:nl]
+ break
+ n = len(data)
+ if n >= left:
+ self._rbuf = data[left:]
+ buffers[-1] = data[:left]
+ break
+ buf_len += n
+ return "".join(buffers)
+
+ def readlines(self, sizehint=0):
+ total = 0
+ list = []
+ while True:
+ line = self.readline()
+ if not line:
+ break
+ list.append(line)
+ total += len(line)
+ if sizehint and total >= sizehint:
+ break
+ return list
+
+ # Iterator protocols
+
+ def __iter__(self):
+ return self
+
+ def next(self):
+ line = self.readline()
+ if not line:
+ raise StopIteration
+ return line
+
+
+
+
+def wrap_socket(sock, keyfile=None, certfile=None,
+ server_side=False, cert_reqs=CERT_NONE,
+ ssl_version=PROTOCOL_SSLv23, ca_certs=None):
+
+ return SSLSocket(sock, keyfile=keyfile, certfile=certfile,
+ server_side=server_side, cert_reqs=cert_reqs,
+ ssl_version=ssl_version, ca_certs=ca_certs)
+
# some utility functions
def cert_time_to_seconds(cert_time):
@@ -228,6 +478,71 @@ def cert_time_to_seconds(cert_time):
import time
return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT"))
+PEM_HEADER = "-----BEGIN CERTIFICATE-----"
+PEM_FOOTER = "-----END CERTIFICATE-----"
+
+def DER_cert_to_PEM_cert(der_cert_bytes):
+
+ """Takes a certificate in binary DER format and returns the
+ PEM version of it as a string."""
+
+ if hasattr(base64, 'standard_b64encode'):
+ # preferred because older API gets line-length wrong
+ f = base64.standard_b64encode(der_cert_bytes)
+ return (PEM_HEADER + '\n' +
+ textwrap.fill(f, 64) +
+ PEM_FOOTER + '\n')
+ else:
+ return (PEM_HEADER + '\n' +
+ base64.encodestring(der_cert_bytes) +
+ PEM_FOOTER + '\n')
+
+def PEM_cert_to_DER_cert(pem_cert_string):
+
+ """Takes a certificate in ASCII PEM format and returns the
+ DER-encoded version of it as a byte sequence"""
+
+ if not pem_cert_string.startswith(PEM_HEADER):
+ raise ValueError("Invalid PEM encoding; must start with %s"
+ % PEM_HEADER)
+ if not pem_cert_string.strip().endswith(PEM_FOOTER):
+ raise ValueError("Invalid PEM encoding; must end with %s"
+ % PEM_FOOTER)
+ d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
+ return base64.decodestring(d)
+
+def get_server_certificate (addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
+
+ """Retrieve the certificate from the server at the specified address,
+ and return it as a PEM-encoded string.
+ If 'ca_certs' is specified, validate the server cert against it.
+ If 'ssl_version' is specified, use it in the connection attempt."""
+
+ host, port = addr
+ if (ca_certs is not None):
+ cert_reqs = CERT_REQUIRED
+ else:
+ cert_reqs = CERT_NONE
+ s = wrap_socket(socket(), ssl_version=ssl_version,
+ cert_reqs=cert_reqs, ca_certs=ca_certs)
+ s.connect(addr)
+ dercert = s.getpeercert(True)
+ s.close()
+ return DER_cert_to_PEM_cert(dercert)
+
+def get_protocol_name (protocol_code):
+ if protocol_code == PROTOCOL_TLSv1:
+ return "TLSv1"
+ elif protocol_code == PROTOCOL_SSLv23:
+ return "SSLv23"
+ elif protocol_code == PROTOCOL_SSLv2:
+ return "SSLv2"
+ elif protocol_code == PROTOCOL_SSLv3:
+ return "SSLv3"
+ else:
+ return "<unknown>"
+
+
# a replacement for the old socket.ssl function
def sslwrap_simple (sock, keyfile=None, certfile=None):