diff options
author | Bill Janssen <janssen@parc.com> | 2007-09-16 22:06:00 (GMT) |
---|---|---|
committer | Bill Janssen <janssen@parc.com> | 2007-09-16 22:06:00 (GMT) |
commit | 296a59d3be01d6ac77fe674333104eb89fd5e695 (patch) | |
tree | 41fddf17b41c6df7a56fbb5e7bafa0c2e489c9cd /Lib/ssl.py | |
parent | 7e84c7f4b5ddf713e940c33ccb82cd1916e937b4 (diff) | |
download | cpython-296a59d3be01d6ac77fe674333104eb89fd5e695.zip cpython-296a59d3be01d6ac77fe674333104eb89fd5e695.tar.gz cpython-296a59d3be01d6ac77fe674333104eb89fd5e695.tar.bz2 |
Add support for asyncore server-side SSL support. This requires
adding the 'makefile' method to ssl.SSLSocket, and importing the
requisite fakefile class from socket.py, and making the appropriate
changes to it to make it use the SSL connection.
Added sample HTTPS server to test_ssl.py, and test that uses it.
Change SSL tests to use https://svn.python.org/, instead of
www.sf.net and pop.gmail.com.
Added utility function to ssl module, get_server_certificate,
to wrap up the several things to be done to pull a certificate
from a remote server.
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r-- | Lib/ssl.py | 336 |
1 files changed, 311 insertions, 25 deletions
@@ -55,7 +55,7 @@ PROTOCOL_SSLv23 PROTOCOL_TLSv1 """ -import os, sys +import os, sys, textwrap import _ssl # if we can't import it, let the error propagate @@ -76,19 +76,7 @@ from _ssl import \ from socket import socket from socket import getnameinfo as _getnameinfo - -def get_protocol_name (protocol_code): - if protocol_code == PROTOCOL_TLSv1: - return "TLSv1" - elif protocol_code == PROTOCOL_SSLv23: - return "SSLv23" - elif protocol_code == PROTOCOL_SSLv2: - return "SSLv2" - elif protocol_code == PROTOCOL_SSLv3: - return "SSLv3" - else: - return "<unknown>" - +import base64 # for DER-to-PEM translation class SSLSocket (socket): @@ -193,21 +181,12 @@ class SSLSocket (socket): else: return socket.recv_from(self, addr, buflen, flags) - def ssl_shutdown(self): - - """Shuts down the SSL channel over this socket (if active), - without closing the socket connection.""" - - if self._sslobj: - self._sslobj.shutdown() - self._sslobj = None - def shutdown(self, how): - self.ssl_shutdown() + self._sslobj = None socket.shutdown(self, how) def close(self): - self.ssl_shutdown() + self._sslobj = None socket.close(self) def connect(self, addr): @@ -236,6 +215,248 @@ class SSLSocket (socket): self.ca_certs), addr) + def makefile(self, mode='r', bufsize=-1): + + """Ouch. Need to make and return a file-like object that + works with the SSL connection.""" + + if self._sslobj: + return SSLFileStream(self._sslobj, mode, bufsize) + else: + return socket.makefile(self, mode, bufsize) + + +class SSLFileStream: + + """A class to simulate a file stream on top of a socket. + Most of this is just lifted from the socket module, and + adjusted to work with an SSL stream instead of a socket.""" + + + default_bufsize = 8192 + name = "<SSL stream>" + + __slots__ = ["mode", "bufsize", "softspace", + # "closed" is a property, see below + "_sslobj", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf", + "_close", "_fileno"] + + def __init__(self, sslobj, mode='rb', bufsize=-1, close=False): + self._sslobj = sslobj + self.mode = mode # Not actually used in this version + if bufsize < 0: + bufsize = self.default_bufsize + self.bufsize = bufsize + self.softspace = False + if bufsize == 0: + self._rbufsize = 1 + elif bufsize == 1: + self._rbufsize = self.default_bufsize + else: + self._rbufsize = bufsize + self._wbufsize = bufsize + self._rbuf = "" # A string + self._wbuf = [] # A list of strings + self._close = close + self._fileno = -1 + + def _getclosed(self): + return self._sslobj is None + closed = property(_getclosed, doc="True if the file is closed") + + def fileno(self): + return self._fileno + + def close(self): + try: + if self._sslobj: + self.flush() + finally: + if self._close and self._sslobj: + self._sslobj.close() + self._sslobj = None + + def __del__(self): + try: + self.close() + except: + # close() may fail if __init__ didn't complete + pass + + def flush(self): + if self._wbuf: + buffer = "".join(self._wbuf) + self._wbuf = [] + count = 0 + while (count < len(buffer)): + written = self._sslobj.write(buffer) + count += written + buffer = buffer[written:] + + def write(self, data): + data = str(data) # XXX Should really reject non-string non-buffers + if not data: + return + self._wbuf.append(data) + if (self._wbufsize == 0 or + self._wbufsize == 1 and '\n' in data or + self._get_wbuf_len() >= self._wbufsize): + self.flush() + + def writelines(self, list): + # XXX We could do better here for very long lists + # XXX Should really reject non-string non-buffers + self._wbuf.extend(filter(None, map(str, list))) + if (self._wbufsize <= 1 or + self._get_wbuf_len() >= self._wbufsize): + self.flush() + + def _get_wbuf_len(self): + buf_len = 0 + for x in self._wbuf: + buf_len += len(x) + return buf_len + + def read(self, size=-1): + data = self._rbuf + if size < 0: + # Read until EOF + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + if self._rbufsize <= 1: + recv_size = self.default_bufsize + else: + recv_size = self._rbufsize + while True: + data = self._sslobj.read(recv_size) + if not data: + break + buffers.append(data) + return "".join(buffers) + else: + # Read until size bytes or EOF seen, whichever comes first + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + left = size - buf_len + recv_size = max(self._rbufsize, left) + data = self._sslobj.read(recv_size) + if not data: + break + buffers.append(data) + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + def readline(self, size=-1): + data = self._rbuf + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + assert data == "" + buffers = [] + while data != "\n": + data = self._sslobj.read(1) + if not data: + break + buffers.append(data) + return "".join(buffers) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = self._sslobj.read(self._rbufsize) + if not data: + break + buffers.append(data) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + return "".join(buffers) + else: + # Read until size bytes or \n or EOF seen, whichever comes first + nl = data.find('\n', 0, size) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = self._sslobj.read(self._rbufsize) + if not data: + break + buffers.append(data) + left = size - buf_len + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + def readlines(self, sizehint=0): + total = 0 + list = [] + while True: + line = self.readline() + if not line: + break + list.append(line) + total += len(line) + if sizehint and total >= sizehint: + break + return list + + # Iterator protocols + + def __iter__(self): + return self + + def next(self): + line = self.readline() + if not line: + raise StopIteration + return line + + + + def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version=PROTOCOL_SSLv23, ca_certs=None): @@ -255,6 +476,71 @@ def cert_time_to_seconds(cert_time): import time return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")) +PEM_HEADER = "-----BEGIN CERTIFICATE-----" +PEM_FOOTER = "-----END CERTIFICATE-----" + +def DER_cert_to_PEM_cert(der_cert_bytes): + + """Takes a certificate in binary DER format and returns the + PEM version of it as a string.""" + + if hasattr(base64, 'standard_b64encode'): + # preferred because older API gets line-length wrong + f = base64.standard_b64encode(der_cert_bytes) + return (PEM_HEADER + '\n' + + textwrap.fill(f, 64) + + PEM_FOOTER + '\n') + else: + return (PEM_HEADER + '\n' + + base64.encodestring(der_cert_bytes) + + PEM_FOOTER + '\n') + +def PEM_cert_to_DER_cert(pem_cert_string): + + """Takes a certificate in ASCII PEM format and returns the + DER-encoded version of it as a byte sequence""" + + if not pem_cert_string.startswith(PEM_HEADER): + raise ValueError("Invalid PEM encoding; must start with %s" + % PEM_HEADER) + if not pem_cert_string.strip().endswith(PEM_FOOTER): + raise ValueError("Invalid PEM encoding; must end with %s" + % PEM_FOOTER) + d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)] + return base64.decodestring(d) + +def get_server_certificate (addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): + + """Retrieve the certificate from the server at the specified address, + and return it as a PEM-encoded string. + If 'ca_certs' is specified, validate the server cert against it. + If 'ssl_version' is specified, use it in the connection attempt.""" + + host, port = addr + if (ca_certs is not None): + cert_reqs = CERT_REQUIRED + else: + cert_reqs = CERT_NONE + s = wrap_socket(socket(), ssl_version=ssl_version, + cert_reqs=cert_reqs, ca_certs=ca_certs) + s.connect(addr) + dercert = s.getpeercert(True) + s.close() + return DER_cert_to_PEM_cert(dercert) + +def get_protocol_name (protocol_code): + if protocol_code == PROTOCOL_TLSv1: + return "TLSv1" + elif protocol_code == PROTOCOL_SSLv23: + return "SSLv23" + elif protocol_code == PROTOCOL_SSLv2: + return "SSLv2" + elif protocol_code == PROTOCOL_SSLv3: + return "SSLv3" + else: + return "<unknown>" + + # a replacement for the old socket.ssl function def sslwrap_simple (sock, keyfile=None, certfile=None): |