diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/ssl.py | 27 | ||||
-rw-r--r-- | Lib/test/test_ssl.py | 54 |
2 files changed, 74 insertions, 7 deletions
@@ -90,7 +90,7 @@ from _ssl import ( SSL_ERROR_EOF, SSL_ERROR_INVALID_ERROR_CODE, ) -from _ssl import HAS_SNI, HAS_ECDH +from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN from _ssl import (PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1) from _ssl import _OPENSSL_API_VERSION @@ -209,6 +209,17 @@ class SSLContext(_SSLContext): server_hostname=server_hostname, _context=self) + def set_npn_protocols(self, npn_protocols): + protos = bytearray() + for protocol in npn_protocols: + b = bytes(protocol, 'ascii') + if len(b) == 0 or len(b) > 255: + raise SSLError('NPN protocols must be 1 to 255 in length') + protos.append(len(b)) + protos.extend(b) + + self._set_npn_protocols(protos) + class SSLSocket(socket): """This class implements a subtype of socket.socket that wraps @@ -220,7 +231,7 @@ class SSLSocket(socket): ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, - suppress_ragged_eofs=True, ciphers=None, + suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, server_hostname=None, _context=None): @@ -240,6 +251,8 @@ class SSLSocket(socket): self.context.load_verify_locations(ca_certs) if certfile: self.context.load_cert_chain(certfile, keyfile) + if npn_protocols: + self.context.set_npn_protocols(npn_protocols) if ciphers: self.context.set_ciphers(ciphers) self.keyfile = keyfile @@ -340,6 +353,13 @@ class SSLSocket(socket): self._checkClosed() return self._sslobj.peer_certificate(binary_form) + def selected_npn_protocol(self): + self._checkClosed() + if not self._sslobj or not _ssl.HAS_NPN: + return None + else: + return self._sslobj.selected_npn_protocol() + def cipher(self): self._checkClosed() if not self._sslobj: @@ -568,7 +588,8 @@ def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, - suppress_ragged_eofs=True, ciphers=None): + suppress_ragged_eofs=True, + ciphers=None): return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=server_side, cert_reqs=cert_reqs, diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index c6ce075..ada3c4b 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -879,6 +879,7 @@ else: try: self.sslconn = self.server.context.wrap_socket( self.sock, server_side=True) + self.server.selected_protocols.append(self.sslconn.selected_npn_protocol()) except ssl.SSLError as e: # XXX Various errors can have happened here, for example # a mismatching protocol version, an invalid certificate, @@ -901,6 +902,8 @@ else: cipher = self.sslconn.cipher() if support.verbose and self.server.chatty: sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") + sys.stdout.write(" server: selected protocol is now " + + str(self.sslconn.selected_npn_protocol()) + "\n") return True def read(self): @@ -979,7 +982,7 @@ else: def __init__(self, certificate=None, ssl_version=None, certreqs=None, cacerts=None, chatty=True, connectionchatty=False, starttls_server=False, - ciphers=None, context=None): + npn_protocols=None, ciphers=None, context=None): if context: self.context = context else: @@ -992,6 +995,8 @@ else: self.context.load_verify_locations(cacerts) if certificate: self.context.load_cert_chain(certificate) + if npn_protocols: + self.context.set_npn_protocols(npn_protocols) if ciphers: self.context.set_ciphers(ciphers) self.chatty = chatty @@ -1001,6 +1006,7 @@ else: self.port = support.bind_port(self.sock) self.flag = None self.active = False + self.selected_protocols = [] self.conn_errors = [] threading.Thread.__init__(self) self.daemon = True @@ -1195,6 +1201,7 @@ else: Launch a server, connect a client to it and try various reads and writes. """ + stats = {} server = ThreadedEchoServer(context=server_context, chatty=chatty, connectionchatty=False) @@ -1220,12 +1227,14 @@ else: if connectionchatty: if support.verbose: sys.stdout.write(" client: closing connection.\n") - stats = { + stats.update({ 'compression': s.compression(), 'cipher': s.cipher(), - } + 'client_npn_protocol': s.selected_npn_protocol() + }) s.close() - return stats + stats['server_npn_protocols'] = server.selected_protocols + return stats def try_protocol_combo(server_protocol, client_protocol, expect_success, certsreqs=None, server_options=0, client_options=0): @@ -1853,6 +1862,43 @@ else: if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: self.fail("Non-DH cipher: " + cipher[0]) + def test_selected_npn_protocol(self): + # selected_npn_protocol() is None unless NPN is used + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + self.assertIs(stats['client_npn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test") + def test_npn_protocols(self): + server_protocols = ['http/1.1', 'spdy/2'] + protocol_tests = [ + (['http/1.1', 'spdy/2'], 'http/1.1'), + (['spdy/2', 'http/1.1'], 'http/1.1'), + (['spdy/2', 'test'], 'spdy/2'), + (['abc', 'def'], 'abc') + ] + for client_protocols, expected in protocol_tests: + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(CERTFILE) + server_context.set_npn_protocols(server_protocols) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.load_cert_chain(CERTFILE) + client_context.set_npn_protocols(client_protocols) + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True) + + msg = "failed trying %s (s) and %s (c).\n" \ + "was expecting %s, but got %%s from the %%s" \ + % (str(server_protocols), str(client_protocols), + str(expected)) + client_result = stats['client_npn_protocol'] + self.assertEqual(client_result, expected, msg % (client_result, "client")) + server_result = stats['server_npn_protocols'][-1] \ + if len(stats['server_npn_protocols']) else 'nothing' + self.assertEqual(server_result, expected, msg % (server_result, "server")) + def test_main(verbose=False): if support.verbose: |