summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r--Lib/test/test_ssl.py64
1 files changed, 60 insertions, 4 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index b7504c6..30af08d 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -1761,7 +1761,8 @@ else:
try:
self.sslconn = self.server.context.wrap_socket(
self.sock, server_side=True)
- self.server.selected_protocols.append(self.sslconn.selected_npn_protocol())
+ self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
+ self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
except (ssl.SSLError, ConnectionResetError) as e:
# We treat ConnectionResetError as though it were an
# SSLError - OpenSSL on Ubuntu abruptly closes the
@@ -1869,7 +1870,8 @@ else:
def __init__(self, certificate=None, ssl_version=None,
certreqs=None, cacerts=None,
chatty=True, connectionchatty=False, starttls_server=False,
- npn_protocols=None, ciphers=None, context=None):
+ npn_protocols=None, alpn_protocols=None,
+ ciphers=None, context=None):
if context:
self.context = context
else:
@@ -1884,6 +1886,8 @@ else:
self.context.load_cert_chain(certificate)
if npn_protocols:
self.context.set_npn_protocols(npn_protocols)
+ if alpn_protocols:
+ self.context.set_alpn_protocols(alpn_protocols)
if ciphers:
self.context.set_ciphers(ciphers)
self.chatty = chatty
@@ -1893,7 +1897,8 @@ else:
self.port = support.bind_port(self.sock)
self.flag = None
self.active = False
- self.selected_protocols = []
+ self.selected_npn_protocols = []
+ self.selected_alpn_protocols = []
self.shared_ciphers = []
self.conn_errors = []
threading.Thread.__init__(self)
@@ -2120,11 +2125,13 @@ else:
'compression': s.compression(),
'cipher': s.cipher(),
'peercert': s.getpeercert(),
+ 'client_alpn_protocol': s.selected_alpn_protocol(),
'client_npn_protocol': s.selected_npn_protocol(),
'version': s.version(),
})
s.close()
- stats['server_npn_protocols'] = server.selected_protocols
+ stats['server_alpn_protocols'] = server.selected_alpn_protocols
+ stats['server_npn_protocols'] = server.selected_npn_protocols
stats['server_shared_ciphers'] = server.shared_ciphers
return stats
@@ -3022,6 +3029,55 @@ else:
if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
self.fail("Non-DH cipher: " + cipher[0])
+ def test_selected_alpn_protocol(self):
+ # selected_alpn_protocol() is None unless ALPN is used.
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_alpn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
+ def test_selected_alpn_protocol_if_server_uses_alpn(self):
+ # selected_alpn_protocol() is None unless ALPN is used by the client.
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.load_verify_locations(CERTFILE)
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(CERTFILE)
+ server_context.set_alpn_protocols(['foo', 'bar'])
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_alpn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
+ def test_alpn_protocols(self):
+ server_protocols = ['foo', 'bar', 'milkshake']
+ protocol_tests = [
+ (['foo', 'bar'], 'foo'),
+ (['bar', 'foo'], 'bar'),
+ (['milkshake'], 'milkshake'),
+ (['http/3.0', 'http/4.0'], 'foo')
+ ]
+ for client_protocols, expected in protocol_tests:
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(CERTFILE)
+ server_context.set_alpn_protocols(server_protocols)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.load_cert_chain(CERTFILE)
+ client_context.set_alpn_protocols(client_protocols)
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True)
+
+ msg = "failed trying %s (s) and %s (c).\n" \
+ "was expecting %s, but got %%s from the %%s" \
+ % (str(server_protocols), str(client_protocols),
+ str(expected))
+ client_result = stats['client_alpn_protocol']
+ self.assertEqual(client_result, expected, msg % (client_result, "client"))
+ server_result = stats['server_alpn_protocols'][-1] \
+ if len(stats['server_alpn_protocols']) else 'nothing'
+ self.assertEqual(server_result, expected, msg % (server_result, "server"))
+
def test_selected_npn_protocol(self):
# selected_npn_protocol() is None unless NPN is used
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)