summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r--Lib/test/test_ssl.py193
1 files changed, 192 insertions, 1 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index b4cafc1..6fd2002 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -218,7 +218,7 @@ def testing_context(server_cert=SIGNED_CERTFILE):
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.load_cert_chain(server_cert)
- client_context.load_verify_locations(SIGNING_CA)
+ server_context.load_verify_locations(SIGNING_CA)
return client_context, server_context, hostname
@@ -2262,6 +2262,23 @@ class ThreadedEchoServer(threading.Thread):
sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
data = self.sslconn.get_channel_binding("tls-unique")
self.write(repr(data).encode("us-ascii") + b"\n")
+ elif stripped == b'PHA':
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: initiating post handshake auth\n")
+ try:
+ self.sslconn.verify_client_post_handshake()
+ except ssl.SSLError as e:
+ self.write(repr(e).encode("us-ascii") + b"\n")
+ else:
+ self.write(b"OK\n")
+ elif stripped == b'HASCERT':
+ if self.sslconn.getpeercert() is not None:
+ self.write(b'TRUE\n')
+ else:
+ self.write(b'FALSE\n')
+ elif stripped == b'GETCERT':
+ cert = self.sslconn.getpeercert()
+ self.write(repr(cert).encode("us-ascii") + b"\n")
else:
if (support.verbose and
self.server.connectionchatty):
@@ -4148,6 +4165,179 @@ class ThreadedTests(unittest.TestCase):
'Session refers to a different SSLContext.')
+@unittest.skipUnless(ssl.HAS_TLSv1_3, "Test needs TLS 1.3")
+class TestPostHandshakeAuth(unittest.TestCase):
+ def test_pha_setter(self):
+ protocols = [
+ ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_SERVER, ssl.PROTOCOL_TLS_CLIENT
+ ]
+ for protocol in protocols:
+ ctx = ssl.SSLContext(protocol)
+ self.assertEqual(ctx.post_handshake_auth, False)
+
+ ctx.post_handshake_auth = True
+ self.assertEqual(ctx.post_handshake_auth, True)
+
+ ctx.verify_mode = ssl.CERT_REQUIRED
+ self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
+ self.assertEqual(ctx.post_handshake_auth, True)
+
+ ctx.post_handshake_auth = False
+ self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
+ self.assertEqual(ctx.post_handshake_auth, False)
+
+ ctx.verify_mode = ssl.CERT_OPTIONAL
+ ctx.post_handshake_auth = True
+ self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
+ self.assertEqual(ctx.post_handshake_auth, True)
+
+ def test_pha_required(self):
+ client_context, server_context, hostname = testing_context()
+ server_context.post_handshake_auth = True
+ server_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.post_handshake_auth = True
+ client_context.load_cert_chain(SIGNED_CERTFILE)
+
+ server = ThreadedEchoServer(context=server_context, chatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ s.write(b'HASCERT')
+ self.assertEqual(s.recv(1024), b'FALSE\n')
+ s.write(b'PHA')
+ self.assertEqual(s.recv(1024), b'OK\n')
+ s.write(b'HASCERT')
+ self.assertEqual(s.recv(1024), b'TRUE\n')
+ # PHA method just returns true when cert is already available
+ s.write(b'PHA')
+ self.assertEqual(s.recv(1024), b'OK\n')
+ s.write(b'GETCERT')
+ cert_text = s.recv(4096).decode('us-ascii')
+ self.assertIn('Python Software Foundation CA', cert_text)
+
+ def test_pha_required_nocert(self):
+ client_context, server_context, hostname = testing_context()
+ server_context.post_handshake_auth = True
+ server_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.post_handshake_auth = True
+
+ server = ThreadedEchoServer(context=server_context, chatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ s.write(b'PHA')
+ # receive CertificateRequest
+ self.assertEqual(s.recv(1024), b'OK\n')
+ # send empty Certificate + Finish
+ s.write(b'HASCERT')
+ # receive alert
+ with self.assertRaisesRegex(
+ ssl.SSLError,
+ 'tlsv13 alert certificate required'):
+ s.recv(1024)
+
+ def test_pha_optional(self):
+ if support.verbose:
+ sys.stdout.write("\n")
+
+ client_context, server_context, hostname = testing_context()
+ server_context.post_handshake_auth = True
+ server_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.post_handshake_auth = True
+ client_context.load_cert_chain(SIGNED_CERTFILE)
+
+ # check CERT_OPTIONAL
+ server_context.verify_mode = ssl.CERT_OPTIONAL
+ server = ThreadedEchoServer(context=server_context, chatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ s.write(b'HASCERT')
+ self.assertEqual(s.recv(1024), b'FALSE\n')
+ s.write(b'PHA')
+ self.assertEqual(s.recv(1024), b'OK\n')
+ s.write(b'HASCERT')
+ self.assertEqual(s.recv(1024), b'TRUE\n')
+
+ def test_pha_optional_nocert(self):
+ if support.verbose:
+ sys.stdout.write("\n")
+
+ client_context, server_context, hostname = testing_context()
+ server_context.post_handshake_auth = True
+ server_context.verify_mode = ssl.CERT_OPTIONAL
+ client_context.post_handshake_auth = True
+
+ server = ThreadedEchoServer(context=server_context, chatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ s.write(b'HASCERT')
+ self.assertEqual(s.recv(1024), b'FALSE\n')
+ s.write(b'PHA')
+ self.assertEqual(s.recv(1024), b'OK\n')
+ # optional doens't fail when client does not have a cert
+ s.write(b'HASCERT')
+ self.assertEqual(s.recv(1024), b'FALSE\n')
+
+ def test_pha_no_pha_client(self):
+ client_context, server_context, hostname = testing_context()
+ server_context.post_handshake_auth = True
+ server_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.load_cert_chain(SIGNED_CERTFILE)
+
+ server = ThreadedEchoServer(context=server_context, chatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ with self.assertRaisesRegex(ssl.SSLError, 'not server'):
+ s.verify_client_post_handshake()
+ s.write(b'PHA')
+ self.assertIn(b'extension not received', s.recv(1024))
+
+ def test_pha_no_pha_server(self):
+ # server doesn't have PHA enabled, cert is requested in handshake
+ client_context, server_context, hostname = testing_context()
+ server_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.post_handshake_auth = True
+ client_context.load_cert_chain(SIGNED_CERTFILE)
+
+ server = ThreadedEchoServer(context=server_context, chatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ s.write(b'HASCERT')
+ self.assertEqual(s.recv(1024), b'TRUE\n')
+ # PHA doesn't fail if there is already a cert
+ s.write(b'PHA')
+ self.assertEqual(s.recv(1024), b'OK\n')
+ s.write(b'HASCERT')
+ self.assertEqual(s.recv(1024), b'TRUE\n')
+
+ def test_pha_not_tls13(self):
+ # TLS 1.2
+ client_context, server_context, hostname = testing_context()
+ server_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.maximum_version = ssl.TLSVersion.TLSv1_2
+ client_context.post_handshake_auth = True
+ client_context.load_cert_chain(SIGNED_CERTFILE)
+
+ server = ThreadedEchoServer(context=server_context, chatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ # PHA fails for TLS != 1.3
+ s.write(b'PHA')
+ self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))
+
+
def test_main(verbose=False):
if support.verbose:
import warnings
@@ -4183,6 +4373,7 @@ def test_main(verbose=False):
tests = [
ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
SSLObjectTests, SimpleBackgroundTests, ThreadedTests,
+ TestPostHandshakeAuth
]
if support.is_resource_enabled('network'):