diff options
author | Christian Heimes <christian@python.org> | 2019-05-31 09:44:05 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-05-31 09:44:05 (GMT) |
commit | c7f7069e77c58e83b847c0bfe4d5aadf6add2e68 (patch) | |
tree | 306bee26619ebc132be4b98fd60d0daf79964cf0 /Lib/test/test_ssl.py | |
parent | e9b51c0ad81da1da11ae65840ac8b50a8521373c (diff) | |
download | cpython-c7f7069e77c58e83b847c0bfe4d5aadf6add2e68.zip cpython-c7f7069e77c58e83b847c0bfe4d5aadf6add2e68.tar.gz cpython-c7f7069e77c58e83b847c0bfe4d5aadf6add2e68.tar.bz2 |
bpo-34271: Add ssl debugging helpers (GH-10031)
The ssl module now can dump key material to a keylog file and trace TLS
protocol messages with a tracing callback. The default and stdlib
contexts also support SSLKEYLOGFILE env var.
The msg_callback and related enums are private members. The feature
is designed for internal debugging and not for end users.
Signed-off-by: Christian Heimes <christian@python.org>
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r-- | Lib/test/test_ssl.py | 168 |
1 files changed, 167 insertions, 1 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index d48d6e5..f368906 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -2,6 +2,7 @@ import sys import unittest +import unittest.mock from test import support import socket import select @@ -25,6 +26,7 @@ except ImportError: ssl = support.import_module("ssl") +from ssl import TLSVersion, _TLSContentType, _TLSMessageType, _TLSAlertType PROTOCOLS = sorted(ssl._PROTOCOL_NAMES) HOST = support.HOST @@ -4405,6 +4407,170 @@ class TestPostHandshakeAuth(unittest.TestCase): self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024)) +HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename') +requires_keylog = unittest.skipUnless( + HAS_KEYLOG, 'test requires OpenSSL 1.1.1 with keylog callback') + +class TestSSLDebug(unittest.TestCase): + + def keylog_lines(self, fname=support.TESTFN): + with open(fname) as f: + return len(list(f)) + + @requires_keylog + def test_keylog_defaults(self): + self.addCleanup(support.unlink, support.TESTFN) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertEqual(ctx.keylog_filename, None) + + self.assertFalse(os.path.isfile(support.TESTFN)) + ctx.keylog_filename = support.TESTFN + self.assertEqual(ctx.keylog_filename, support.TESTFN) + self.assertTrue(os.path.isfile(support.TESTFN)) + self.assertEqual(self.keylog_lines(), 1) + + ctx.keylog_filename = None + self.assertEqual(ctx.keylog_filename, None) + + with self.assertRaises((IsADirectoryError, PermissionError)): + # Windows raises PermissionError + ctx.keylog_filename = os.path.dirname( + os.path.abspath(support.TESTFN)) + + with self.assertRaises(TypeError): + ctx.keylog_filename = 1 + + @requires_keylog + def test_keylog_filename(self): + self.addCleanup(support.unlink, support.TESTFN) + client_context, server_context, hostname = testing_context() + + client_context.keylog_filename = support.TESTFN + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + s.connect((HOST, server.port)) + # header, 5 lines for TLS 1.3 + self.assertEqual(self.keylog_lines(), 6) + + client_context.keylog_filename = None + server_context.keylog_filename = support.TESTFN + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + s.connect((HOST, server.port)) + self.assertGreaterEqual(self.keylog_lines(), 11) + + client_context.keylog_filename = support.TESTFN + server_context.keylog_filename = support.TESTFN + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + s.connect((HOST, server.port)) + self.assertGreaterEqual(self.keylog_lines(), 21) + + client_context.keylog_filename = None + server_context.keylog_filename = None + + @requires_keylog + @unittest.skipIf(sys.flags.ignore_environment, + "test is not compatible with ignore_environment") + def test_keylog_env(self): + self.addCleanup(support.unlink, support.TESTFN) + with unittest.mock.patch.dict(os.environ): + os.environ['SSLKEYLOGFILE'] = support.TESTFN + self.assertEqual(os.environ['SSLKEYLOGFILE'], support.TESTFN) + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertEqual(ctx.keylog_filename, None) + + ctx = ssl.create_default_context() + self.assertEqual(ctx.keylog_filename, support.TESTFN) + + ctx = ssl._create_stdlib_context() + self.assertEqual(ctx.keylog_filename, support.TESTFN) + + def test_msg_callback(self): + client_context, server_context, hostname = testing_context() + + def msg_cb(conn, direction, version, content_type, msg_type, data): + pass + + self.assertIs(client_context._msg_callback, None) + client_context._msg_callback = msg_cb + self.assertIs(client_context._msg_callback, msg_cb) + with self.assertRaises(TypeError): + client_context._msg_callback = object() + + def test_msg_callback_tls12(self): + client_context, server_context, hostname = testing_context() + client_context.options |= ssl.OP_NO_TLSv1_3 + + msg = [] + + def msg_cb(conn, direction, version, content_type, msg_type, data): + self.assertIsInstance(conn, ssl.SSLSocket) + self.assertIsInstance(data, bytes) + self.assertIn(direction, {'read', 'write'}) + msg.append((direction, version, content_type, msg_type)) + + client_context._msg_callback = msg_cb + + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + s.connect((HOST, server.port)) + + self.assertEqual(msg, [ + ("write", TLSVersion.TLSv1, _TLSContentType.HEADER, + _TLSMessageType.CERTIFICATE_STATUS), + ("write", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE, + _TLSMessageType.CLIENT_HELLO), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER, + _TLSMessageType.CERTIFICATE_STATUS), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE, + _TLSMessageType.SERVER_HELLO), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER, + _TLSMessageType.CERTIFICATE_STATUS), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE, + _TLSMessageType.CERTIFICATE), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER, + _TLSMessageType.CERTIFICATE_STATUS), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE, + _TLSMessageType.SERVER_KEY_EXCHANGE), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER, + _TLSMessageType.CERTIFICATE_STATUS), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE, + _TLSMessageType.SERVER_DONE), + ("write", TLSVersion.TLSv1_2, _TLSContentType.HEADER, + _TLSMessageType.CERTIFICATE_STATUS), + ("write", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE, + _TLSMessageType.CLIENT_KEY_EXCHANGE), + ("write", TLSVersion.TLSv1_2, _TLSContentType.HEADER, + _TLSMessageType.FINISHED), + ("write", TLSVersion.TLSv1_2, _TLSContentType.CHANGE_CIPHER_SPEC, + _TLSMessageType.CHANGE_CIPHER_SPEC), + ("write", TLSVersion.TLSv1_2, _TLSContentType.HEADER, + _TLSMessageType.CERTIFICATE_STATUS), + ("write", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE, + _TLSMessageType.FINISHED), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER, + _TLSMessageType.CERTIFICATE_STATUS), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE, + _TLSMessageType.NEWSESSION_TICKET), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER, + _TLSMessageType.FINISHED), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER, + _TLSMessageType.CERTIFICATE_STATUS), + ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE, + _TLSMessageType.FINISHED), + ]) + + def test_main(verbose=False): if support.verbose: import warnings @@ -4440,7 +4606,7 @@ def test_main(verbose=False): tests = [ ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests, SSLObjectTests, SimpleBackgroundTests, ThreadedTests, - TestPostHandshakeAuth + TestPostHandshakeAuth, TestSSLDebug ] if support.is_resource_enabled('network'): |