diff options
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r-- | Lib/ssl.py | 361 |
1 files changed, 338 insertions, 23 deletions
@@ -8,11 +8,11 @@ This module provides some more Pythonic support for SSL. Object types: - sslsocket -- subtype of socket.socket which does SSL over the socket + SSLSocket -- subtype of socket.socket which does SSL over the socket Exceptions: - sslerror -- exception raised for I/O errors + SSLError -- exception raised for I/O errors Functions: @@ -57,12 +57,14 @@ PROTOCOL_SSLv23 PROTOCOL_TLSv1 """ -import os, sys +import os, sys, textwrap import _ssl # if we can't import it, let the error propagate -from _ssl import sslerror + +from _ssl import SSLError from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 +from _ssl import RAND_status, RAND_egd, RAND_add from _ssl import \ SSL_ERROR_ZERO_RETURN, \ SSL_ERROR_WANT_READ, \ @@ -76,9 +78,9 @@ from _ssl import \ from socket import socket from socket import getnameinfo as _getnameinfo +import base64 # for DER-to-PEM translation - -class sslsocket (socket): +class SSLSocket (socket): """This class implements a subtype of socket.socket that wraps the underlying OS socket in an SSL context when necessary, and @@ -121,14 +123,21 @@ class sslsocket (socket): return self._sslobj.write(data) - def getpeercert(self): + def getpeercert(self, binary_form=False): """Returns a formatted version of the data in the certificate provided by the other end of the SSL channel. Return None if no certificate was provided, {} if a certificate was provided, but not validated.""" - return self._sslobj.peer_certificate() + return self._sslobj.peer_certificate(binary_form) + + def cipher (self): + + if not self._sslobj: + return None + else: + return self._sslobj.cipher() def send (self, data, flags=0): if self._sslobj: @@ -174,21 +183,12 @@ class sslsocket (socket): else: return socket.recv_from(self, addr, buflen, flags) - def ssl_shutdown(self): - - """Shuts down the SSL channel over this socket (if active), - without closing the socket connection.""" - - if self._sslobj: - self._sslobj.shutdown() - self._sslobj = None - def shutdown(self, how): - self.ssl_shutdown() + self._sslobj = None socket.shutdown(self, how) def close(self): - self.ssl_shutdown() + self._sslobj = None socket.close(self) def connect(self, addr): @@ -199,7 +199,7 @@ class sslsocket (socket): # Here we assume that the socket is client-side, and not # connected at the time of the call. We connect it, then wrap it. if self._sslobj: - raise ValueError("attempt to connect already-connected sslsocket!") + raise ValueError("attempt to connect already-connected SSLSocket!") socket.connect(self, addr) self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, self.cert_reqs, self.ssl_version, @@ -212,11 +212,261 @@ class sslsocket (socket): SSL channel, and the address of the remote client.""" newsock, addr = socket.accept(self) - return (sslsocket(newsock, True, self.keyfile, self.certfile, - self.cert_reqs, self.ssl_version, - self.ca_certs), addr) + return (SSLSocket(newsock, True, self.keyfile, self.certfile, + self.cert_reqs, self.ssl_version, + self.ca_certs), addr) + + + def makefile(self, mode='r', bufsize=-1): + + """Ouch. Need to make and return a file-like object that + works with the SSL connection.""" + + if self._sslobj: + return SSLFileStream(self._sslobj, mode, bufsize) + else: + return socket.makefile(self, mode, bufsize) + + +class SSLFileStream: + + """A class to simulate a file stream on top of a socket. + Most of this is just lifted from the socket module, and + adjusted to work with an SSL stream instead of a socket.""" + default_bufsize = 8192 + name = "<SSL stream>" + + __slots__ = ["mode", "bufsize", "softspace", + # "closed" is a property, see below + "_sslobj", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf", + "_close", "_fileno"] + + def __init__(self, sslobj, mode='rb', bufsize=-1, close=False): + self._sslobj = sslobj + self.mode = mode # Not actually used in this version + if bufsize < 0: + bufsize = self.default_bufsize + self.bufsize = bufsize + self.softspace = False + if bufsize == 0: + self._rbufsize = 1 + elif bufsize == 1: + self._rbufsize = self.default_bufsize + else: + self._rbufsize = bufsize + self._wbufsize = bufsize + self._rbuf = "" # A string + self._wbuf = [] # A list of strings + self._close = close + self._fileno = -1 + + def _getclosed(self): + return self._sslobj is None + closed = property(_getclosed, doc="True if the file is closed") + + def fileno(self): + return self._fileno + + def close(self): + try: + if self._sslobj: + self.flush() + finally: + if self._close and self._sslobj: + self._sslobj.close() + self._sslobj = None + + def __del__(self): + try: + self.close() + except: + # close() may fail if __init__ didn't complete + pass + + def flush(self): + if self._wbuf: + buffer = "".join(self._wbuf) + self._wbuf = [] + count = 0 + while (count < len(buffer)): + written = self._sslobj.write(buffer) + count += written + buffer = buffer[written:] + + def write(self, data): + data = str(data) # XXX Should really reject non-string non-buffers + if not data: + return + self._wbuf.append(data) + if (self._wbufsize == 0 or + self._wbufsize == 1 and '\n' in data or + self._get_wbuf_len() >= self._wbufsize): + self.flush() + + def writelines(self, list): + # XXX We could do better here for very long lists + # XXX Should really reject non-string non-buffers + self._wbuf.extend(filter(None, map(str, list))) + if (self._wbufsize <= 1 or + self._get_wbuf_len() >= self._wbufsize): + self.flush() + + def _get_wbuf_len(self): + buf_len = 0 + for x in self._wbuf: + buf_len += len(x) + return buf_len + + def read(self, size=-1): + data = self._rbuf + if size < 0: + # Read until EOF + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + if self._rbufsize <= 1: + recv_size = self.default_bufsize + else: + recv_size = self._rbufsize + while True: + data = self._sslobj.read(recv_size) + if not data: + break + buffers.append(data) + return "".join(buffers) + else: + # Read until size bytes or EOF seen, whichever comes first + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + left = size - buf_len + recv_size = max(self._rbufsize, left) + data = self._sslobj.read(recv_size) + if not data: + break + buffers.append(data) + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + def readline(self, size=-1): + data = self._rbuf + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + assert data == "" + buffers = [] + while data != "\n": + data = self._sslobj.read(1) + if not data: + break + buffers.append(data) + return "".join(buffers) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = self._sslobj.read(self._rbufsize) + if not data: + break + buffers.append(data) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + return "".join(buffers) + else: + # Read until size bytes or \n or EOF seen, whichever comes first + nl = data.find('\n', 0, size) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = self._sslobj.read(self._rbufsize) + if not data: + break + buffers.append(data) + left = size - buf_len + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + def readlines(self, sizehint=0): + total = 0 + list = [] + while True: + line = self.readline() + if not line: + break + list.append(line) + total += len(line) + if sizehint and total >= sizehint: + break + return list + + # Iterator protocols + + def __iter__(self): + return self + + def next(self): + line = self.readline() + if not line: + raise StopIteration + return line + + + + +def wrap_socket(sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None): + + return SSLSocket(sock, keyfile=keyfile, certfile=certfile, + server_side=server_side, cert_reqs=cert_reqs, + ssl_version=ssl_version, ca_certs=ca_certs) + # some utility functions def cert_time_to_seconds(cert_time): @@ -228,6 +478,71 @@ def cert_time_to_seconds(cert_time): import time return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")) +PEM_HEADER = "-----BEGIN CERTIFICATE-----" +PEM_FOOTER = "-----END CERTIFICATE-----" + +def DER_cert_to_PEM_cert(der_cert_bytes): + + """Takes a certificate in binary DER format and returns the + PEM version of it as a string.""" + + if hasattr(base64, 'standard_b64encode'): + # preferred because older API gets line-length wrong + f = base64.standard_b64encode(der_cert_bytes) + return (PEM_HEADER + '\n' + + textwrap.fill(f, 64) + + PEM_FOOTER + '\n') + else: + return (PEM_HEADER + '\n' + + base64.encodestring(der_cert_bytes) + + PEM_FOOTER + '\n') + +def PEM_cert_to_DER_cert(pem_cert_string): + + """Takes a certificate in ASCII PEM format and returns the + DER-encoded version of it as a byte sequence""" + + if not pem_cert_string.startswith(PEM_HEADER): + raise ValueError("Invalid PEM encoding; must start with %s" + % PEM_HEADER) + if not pem_cert_string.strip().endswith(PEM_FOOTER): + raise ValueError("Invalid PEM encoding; must end with %s" + % PEM_FOOTER) + d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)] + return base64.decodestring(d) + +def get_server_certificate (addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): + + """Retrieve the certificate from the server at the specified address, + and return it as a PEM-encoded string. + If 'ca_certs' is specified, validate the server cert against it. + If 'ssl_version' is specified, use it in the connection attempt.""" + + host, port = addr + if (ca_certs is not None): + cert_reqs = CERT_REQUIRED + else: + cert_reqs = CERT_NONE + s = wrap_socket(socket(), ssl_version=ssl_version, + cert_reqs=cert_reqs, ca_certs=ca_certs) + s.connect(addr) + dercert = s.getpeercert(True) + s.close() + return DER_cert_to_PEM_cert(dercert) + +def get_protocol_name (protocol_code): + if protocol_code == PROTOCOL_TLSv1: + return "TLSv1" + elif protocol_code == PROTOCOL_SSLv23: + return "SSLv23" + elif protocol_code == PROTOCOL_SSLv2: + return "SSLv2" + elif protocol_code == PROTOCOL_SSLv3: + return "SSLv3" + else: + return "<unknown>" + + # a replacement for the old socket.ssl function def sslwrap_simple (sock, keyfile=None, certfile=None): |