diff options
author | Guido van Rossum <guido@python.org> | 2007-08-25 15:08:43 (GMT) |
---|---|---|
committer | Guido van Rossum <guido@python.org> | 2007-08-25 15:08:43 (GMT) |
commit | 4f2c3ddca45c11d466bf487d16d74fe875536e3f (patch) | |
tree | 494ac4ce52ddc06df41589ba3e0080ea48b5851c /Lib | |
parent | 1a42ece0c76166b1dead10decb0e54af084b4eb2 (diff) | |
download | cpython-4f2c3ddca45c11d466bf487d16d74fe875536e3f.zip cpython-4f2c3ddca45c11d466bf487d16d74fe875536e3f.tar.gz cpython-4f2c3ddca45c11d466bf487d16d74fe875536e3f.tar.bz2 |
Server-side SSL and certificate validation, by Bill Janssen.
While cleaning up Bill's C style, I may have cleaned up some code
he didn't touch as well (in _ssl.c).
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/socket.py | 7 | ||||
-rw-r--r-- | Lib/ssl.py | 252 | ||||
-rw-r--r-- | Lib/test/test_ssl.py | 304 |
3 files changed, 559 insertions, 4 deletions
diff --git a/Lib/socket.py b/Lib/socket.py index 45a122f..48bb4f6 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -68,11 +68,10 @@ if _have_ssl: _realsocket = socket if _have_ssl: - _realssl = ssl def ssl(sock, keyfile=None, certfile=None): - if hasattr(sock, "_sock"): - sock = sock._sock - return _realssl(sock, keyfile, certfile) + import ssl as realssl + return realssl.sslwrap_simple(sock, keyfile, certfile) + __all__.append("ssl") # WSA error codes if sys.platform.lower().startswith("win"): diff --git a/Lib/ssl.py b/Lib/ssl.py new file mode 100644 index 0000000..17a48ea --- /dev/null +++ b/Lib/ssl.py @@ -0,0 +1,252 @@ +# Wrapper module for _ssl, providing some additional facilities +# implemented in Python. Written by Bill Janssen. + +"""\ +This module provides some more Pythonic support for SSL. + +Object types: + + sslsocket -- subtype of socket.socket which does SSL over the socket + +Exceptions: + + sslerror -- exception raised for I/O errors + +Functions: + + cert_time_to_seconds -- convert time string used for certificate + notBefore and notAfter functions to integer + seconds past the Epoch (the time values + returned from time.time()) + + fetch_server_certificate (HOST, PORT) -- fetch the certificate provided + by the server running on HOST at port PORT. No + validation of the certificate is performed. + +Integer constants: + +SSL_ERROR_ZERO_RETURN +SSL_ERROR_WANT_READ +SSL_ERROR_WANT_WRITE +SSL_ERROR_WANT_X509_LOOKUP +SSL_ERROR_SYSCALL +SSL_ERROR_SSL +SSL_ERROR_WANT_CONNECT + +SSL_ERROR_EOF +SSL_ERROR_INVALID_ERROR_CODE + +The following group define certificate requirements that one side is +allowing/requiring from the other side: + +CERT_NONE - no certificates from the other side are required (or will + be looked at if provided) +CERT_OPTIONAL - certificates are not required, but if provided will be + validated, and if validation fails, the connection will + also fail +CERT_REQUIRED - certificates are required, and will be validated, and + if validation fails, the connection will also fail + +The following constants identify various SSL protocol variants: + +PROTOCOL_SSLv2 +PROTOCOL_SSLv3 +PROTOCOL_SSLv23 +PROTOCOL_TLSv1 +""" + +import os, sys + +import _ssl # if we can't import it, let the error propagate +from socket import socket +from _ssl import sslerror +from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED +from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 + +# Root certs: +# +# The "ca_certs" argument to sslsocket() expects a file containing one or more +# certificates that are roots of various certificate signing chains. This file +# contains the certificates in PEM format (RFC ) where each certificate is +# encoded in base64 encoding and surrounded with a header and footer: +# -----BEGIN CERTIFICATE----- +# ... (CA certificate in base64 encoding) ... +# -----END CERTIFICATE----- +# The various certificates in the file are just concatenated together: +# -----BEGIN CERTIFICATE----- +# ... (CA certificate in base64 encoding) ... +# -----END CERTIFICATE----- +# -----BEGIN CERTIFICATE----- +# ... (a second CA certificate in base64 encoding) ... +# -----END CERTIFICATE----- +# +# Some "standard" root certificates are available at +# +# http://www.thawte.com/roots/ (for Thawte roots) +# http://www.verisign.com/support/roots.html (for Verisign) + +class sslsocket (socket): + + def __init__(self, sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None): + socket.__init__(self, _sock=sock._sock) + if certfile and not keyfile: + keyfile = certfile + if server_side: + self._sslobj = _ssl.sslwrap(self._sock, 1, keyfile, certfile, + cert_reqs, ssl_version, ca_certs) + else: + # see if it's connected + try: + socket.getpeername(self) + # yes + self._sslobj = _ssl.sslwrap(self._sock, 0, keyfile, certfile, + cert_reqs, ssl_version, ca_certs) + except: + # no + self._sslobj = None + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.ca_certs = ca_certs + + def read(self, len=1024): + return self._sslobj.read(len) + + def write(self, data): + return self._sslobj.write(data) + + def getpeercert(self): + return self._sslobj.peer_certificate() + + def send (self, data, flags=0): + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to send() on %s" % + self.__class__) + return self._sslobj.write(data) + + def send_to (self, data, addr, flags=0): + raise ValueError("send_to not allowed on instances of %s" % + self.__class__) + + def sendall (self, data, flags=0): + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to sendall() on %s" % + self.__class__) + return self._sslobj.write(data) + + def recv (self, buflen=1024, flags=0): + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to sendall() on %s" % + self.__class__) + return self._sslobj.read(data, buflen) + + def recv_from (self, addr, buflen=1024, flags=0): + raise ValueError("recv_from not allowed on instances of %s" % + self.__class__) + + def shutdown(self): + if self._sslobj: + self._sslobj.shutdown() + self._sslobj = None + else: + socket.shutdown(self) + + def close(self): + if self._sslobj: + self.shutdown() + else: + socket.close(self) + + def connect(self, addr): + # Here we assume that the socket is client-side, and not + # connected at the time of the call. We connect it, then wrap it. + if self._sslobj or (self.getsockname()[1] != 0): + raise ValueError("attempt to connect already-connected sslsocket!") + socket.connect(self, addr) + self._sslobj = _ssl.sslwrap(self._sock, 0, self.keyfile, self.certfile, + self.cert_reqs, self.ssl_version, + self.ca_certs) + + def accept(self): + raise ValueError("accept() not supported on an sslsocket") + + +# some utility functions + +def cert_time_to_seconds(cert_time): + import time + return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")) + +# a replacement for the old socket.ssl function + +def sslwrap_simple (sock, keyfile=None, certfile=None): + + return _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE, + PROTOCOL_SSLv23, None) + +# fetch the certificate that the server is providing in PEM form + +def fetch_server_certificate (host, port): + + import re, tempfile, os + + def subproc(cmd): + from subprocess import Popen, PIPE, STDOUT + proc = Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) + status = proc.wait() + output = proc.stdout.read() + return status, output + + def strip_to_x509_cert(certfile_contents, outfile=None): + m = re.search(r"^([-]+BEGIN CERTIFICATE[-]+[\r]*\n" + r".*[\r]*^[-]+END CERTIFICATE[-]+)$", + certfile_contents, re.MULTILINE | re.DOTALL) + if not m: + return None + else: + tn = tempfile.mktemp() + fp = open(tn, "w") + fp.write(m.group(1) + "\n") + fp.close() + try: + tn2 = (outfile or tempfile.mktemp()) + status, output = subproc(r'openssl x509 -in "%s" -out "%s"' % + (tn, tn2)) + if status != 0: + raise OperationError(status, tsig, output) + fp = open(tn2, 'rb') + data = fp.read() + fp.close() + os.unlink(tn2) + return data + finally: + os.unlink(tn) + + if sys.platform.startswith("win"): + tfile = tempfile.mktemp() + fp = open(tfile, "w") + fp.write("quit\n") + fp.close() + try: + status, output = subproc( + 'openssl s_client -connect "%s:%s" -showcerts < "%s"' % + (host, port, tfile)) + finally: + os.unlink(tfile) + else: + status, output = subproc( + 'openssl s_client -connect "%s:%s" -showcerts < /dev/null' % + (host, port)) + if status != 0: + raise OSError(status) + certtext = strip_to_x509_cert(output) + if not certtext: + raise ValueError("Invalid response received from server at %s:%s" % + (host, port)) + return certtext diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py new file mode 100644 index 0000000..0188dc2 --- /dev/null +++ b/Lib/test/test_ssl.py @@ -0,0 +1,304 @@ +# Test the support for SSL and sockets + +import sys +import unittest +from test import test_support +import socket +import errno +import threading +import subprocess +import time +import os +import pprint +import urllib +import shutil +import string +import traceback + +# Optionally test SSL support, if we have it in the tested platform +skip_expected = False +try: + import ssl +except ImportError: + skip_expected = True + +CERTFILE = None +GMAIL_POP_CERTFILE = None + +class BasicTests(unittest.TestCase): + + def testRudeShutdown(self): + # Some random port to connect to. + PORT = [9934] + + listener_ready = threading.Event() + listener_gone = threading.Event() + + # `listener` runs in a thread. It opens a socket listening on + # PORT, and sits in an accept() until the main thread connects. + # Then it rudely closes the socket, and sets Event `listener_gone` + # to let the main thread know the socket is gone. + def listener(): + s = socket.socket() + PORT[0] = test_support.bind_port(s, '', PORT[0]) + s.listen(5) + listener_ready.set() + s.accept() + s = None # reclaim the socket object, which also closes it + listener_gone.set() + + def connector(): + listener_ready.wait() + s = socket.socket() + s.connect(('localhost', PORT[0])) + listener_gone.wait() + try: + ssl_sock = socket.ssl(s) + except socket.sslerror: + pass + else: + raise test_support.TestFailed( + 'connecting to closed SSL socket should have failed') + + t = threading.Thread(target=listener) + t.start() + connector() + t.join() + + def testSSLconnect(self): + import os + with test_support.transient_internet(): + s = ssl.sslsocket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_NONE) + s.connect(("pop.gmail.com", 995)) + c = s.getpeercert() + if c: + raise test_support.TestFailed("Peer cert %s shouldn't be here!") + s.close() + + # this should fail because we have no verification certs + s = ssl.sslsocket(socket.socket(socket.AF_INET), + cert_reqs=ssl.CERT_REQUIRED) + try: + s.connect(("pop.gmail.com", 995)) + except ssl.sslerror: + pass + finally: + s.close() + +class ConnectedTests(unittest.TestCase): + + def testTLSecho (self): + + s1 = socket.socket() + s1.connect(('127.0.0.1', 10024)) + c1 = ssl.sslsocket(s1, ssl_version=ssl.PROTOCOL_TLSv1) + indata = "FOO\n" + c1.write(indata) + outdata = c1.read() + if outdata != indata.lower(): + sys.stderr.write("bad data <<%s>> received\n" % data) + c1.close() + + def testReadCert(self): + + s2 = socket.socket() + s2.connect(('127.0.0.1', 10024)) + c2 = ssl.sslsocket(s2, ssl_version=ssl.PROTOCOL_TLSv1, + cert_reqs=ssl.CERT_REQUIRED, ca_certs=CERTFILE) + cert = c2.getpeercert() + if not cert: + raise test_support.TestFailed("Can't get peer certificate.") + if not cert.has_key('subject'): + raise test_support.TestFailed( + "No subject field in certificate: %s." % + pprint.pformat(cert)) + if not (cert['subject'].has_key('organizationName')): + raise test_support.TestFailed( + "No 'organizationName' field in certificate subject: %s." % + pprint.pformat(cert)) + if (cert['subject']['organizationName'] != + "Python Software Foundation"): + raise test_support.TestFailed( + "Invalid 'organizationName' field in certificate subject; " + "should be 'Python Software Foundation'."); + c2.close() + + +class threadedEchoServer(threading.Thread): + + class connectionHandler(threading.Thread): + + def __init__(self, server, connsock): + self.server = server + self.running = False + self.sock = connsock + threading.Thread.__init__(self) + self.setDaemon(True) + + def run (self): + self.running = True + sslconn = ssl.sslsocket(self.sock, server_side=True, + certfile=self.server.certificate, + ssl_version=self.server.protocol, + cert_reqs=self.server.certreqs) + while self.running: + try: + msg = sslconn.read() + if not msg: + # eof, so quit this handler + self.running = False + sslconn.close() + elif msg.strip() == 'over': + sslconn.close() + self.server.stop() + self.running = False + else: + # print "server:", msg.strip().lower() + sslconn.write(msg.lower()) + except ssl.sslerror: + sys.stderr.write(string.join( + traceback.format_exception(*sys.exc_info()))) + sslconn.close() + self.running = False + except: + sys.stderr.write(string.join( + traceback.format_exception(*sys.exc_info()))) + + def __init__(self, port, certificate, ssl_version=ssl.PROTOCOL_TLSv1, + certreqs=ssl.CERT_NONE, cacerts=None): + self.certificate = certificate + self.protocol = ssl_version + self.certreqs = certreqs + self.cacerts = cacerts + self.sock = socket.socket() + self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + self.sock.bind(('127.0.0.1', port)) + self.active = False + threading.Thread.__init__(self) + self.setDaemon(False) + + def run (self): + self.sock.settimeout(0.5) + self.sock.listen(5) + self.active = True + while self.active: + try: + newconn, connaddr = self.sock.accept() + # sys.stderr.write('new connection from ' + str(connaddr)) + handler = self.connectionHandler(self, newconn) + handler.start() + except socket.timeout: + pass + except KeyboardInterrupt: + self.active = False + except: + sys.stderr.write(string.join( + traceback.format_exception(*sys.exc_info()))) + + def stop (self): + self.active = False + + +CERTFILE_CONFIG_TEMPLATE = """ +# create RSA certs - Server + +[ req ] +default_bits = 1024 +encrypt_key = yes +distinguished_name = req_dn +x509_extensions = cert_type + +[ req_dn ] +countryName = Country Name (2 letter code) +countryName_default = US +countryName_min = 2 +countryName_max = 2 + +stateOrProvinceName = State or Province Name (full name) +stateOrProvinceName_default = %(state)s + +localityName = Locality Name (eg, city) +localityName_default = %(city)s + +0.organizationName = Organization Name (eg, company) +0.organizationName_default = %(organization)s + +organizationalUnitName = Organizational Unit Name (eg, section) +organizationalUnitName_default = %(unit)s + +0.commonName = Common Name (FQDN of your server) +0.commonName_default = %(common-name)s + +# To create a certificate for more than one name uncomment: +# 1.commonName = DNS alias of your server +# 2.commonName = DNS alias of your server +# ... +# See http://home.netscape.com/eng/security/ssl_2.0_certificate.html +# to see how Netscape understands commonName. + +[ cert_type ] +nsCertType = server +""" + +def create_cert_files(): + + import tempfile, socket, os + d = tempfile.mkdtemp() + # now create a configuration file for the CA signing cert + fqdn = socket.getfqdn() + crtfile = os.path.join(d, "cert.pem") + conffile = os.path.join(d, "ca.conf") + fp = open(conffile, "w") + fp.write(CERTFILE_CONFIG_TEMPLATE % + {'state': "Delaware", + 'city': "Wilmington", + 'organization': "Python Software Foundation", + 'unit': "SSL", + 'common-name': fqdn, + }) + fp.close() + os.system( + "openssl req -batch -new -x509 -days 10 -nodes -config %s " + "-keyout \"%s\" -out \"%s\" > /dev/null < /dev/null 2>&1" % + (conffile, crtfile, crtfile)) + # now we have a self-signed server cert in crtfile + os.unlink(conffile) + #sf_certfile = os.path.join(d, "sourceforge-imap.pem") + #sf_cert = ssl.fetch_server_certificate('pop.gmail.com', 995) + #open(sf_certfile, 'w').write(sf_cert) + #return d, crtfile, sf_certfile + # sys.stderr.write(open(crtfile, 'r').read() + '\n') + return d, crtfile + +def test_main(): + if skip_expected: + raise test_support.TestSkipped("socket module has no ssl support") + + global CERTFILE + tdir, CERTFILE = create_cert_files() + + tests = [BasicTests] + + server = None + if test_support.is_resource_enabled('network'): + server = threadedEchoServer(10024, CERTFILE) + server.start() + time.sleep(1) + tests.append(ConnectedTests) + + thread_info = test_support.threading_setup() + + try: + test_support.run_unittest(*tests) + finally: + if server is not None and server.active: + server.stop() + # wait for it to stop + server.join() + + shutil.rmtree(tdir) + test_support.threading_cleanup(*thread_info) + +if __name__ == "__main__": + test_main() |