diff options
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r-- | Lib/test/test_ssl.py | 301 |
1 files changed, 264 insertions, 37 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index c6c26bc..86d1cd8 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -42,6 +42,9 @@ ONLYCERT = data_file("ssl_cert.pem") ONLYKEY = data_file("ssl_key.pem") BYTES_ONLYCERT = os.fsencode(ONLYCERT) BYTES_ONLYKEY = os.fsencode(ONLYKEY) +CERTFILE_PROTECTED = data_file("keycert.passwd.pem") +ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem") +KEY_PASSWORD = "somepass" CAPATH = data_file("capath") BYTES_CAPATH = os.fsencode(CAPATH) @@ -53,6 +56,8 @@ WRONGCERT = data_file("XXXnonexisting.pem") BADKEY = data_file("badkey.pem") NOKIACERT = data_file("nokia.pem") +DHFILE = data_file("dh512.pem") +BYTES_DHFILE = os.fsencode(DHFILE) def handle_error(prefix): exc_format = ' '.join(traceback.format_exception(*sys.exc_info())) @@ -95,7 +100,13 @@ class BasicSocketTests(unittest.TestCase): ssl.CERT_NONE ssl.CERT_OPTIONAL ssl.CERT_REQUIRED + ssl.OP_CIPHER_SERVER_PREFERENCE + ssl.OP_SINGLE_DH_USE + ssl.OP_SINGLE_ECDH_USE + if ssl.OPENSSL_VERSION_INFO >= (1, 0): + ssl.OP_NO_COMPRESSION self.assertIn(ssl.HAS_SNI, {True, False}) + self.assertIn(ssl.HAS_ECDH, {True, False}) def test_random(self): v = ssl.RAND_status() @@ -103,6 +114,16 @@ class BasicSocketTests(unittest.TestCase): sys.stdout.write("\n RAND_status is %d (%s)\n" % (v, (v and "sufficient randomness") or "insufficient randomness")) + + data, is_cryptographic = ssl.RAND_pseudo_bytes(16) + self.assertEqual(len(data), 16) + self.assertEqual(is_cryptographic, v == 1) + if v: + data = ssl.RAND_bytes(16) + self.assertEqual(len(data), 16) + else: + self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16) + try: ssl.RAND_egd(1) except TypeError: @@ -337,6 +358,25 @@ class BasicSocketTests(unittest.TestCase): self.assertRaises(ValueError, ctx.wrap_socket, sock, True, server_hostname="some.hostname") + def test_unknown_channel_binding(self): + # should raise ValueError for unknown type + s = socket.socket(socket.AF_INET) + ss = ssl.wrap_socket(s) + with self.assertRaises(ValueError): + ss.get_channel_binding("unknown-type") + + @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, + "'tls-unique' channel binding not available") + def test_tls_unique_channel_binding(self): + # unconnected should return None for known type + s = socket.socket(socket.AF_INET) + ss = ssl.wrap_socket(s) + self.assertIsNone(ss.get_channel_binding("tls-unique")) + # the same for server-side + s = socket.socket(socket.AF_INET) + ss = ssl.wrap_socket(s, server_side=True, certfile=CERTFILE) + self.assertIsNone(ss.get_channel_binding("tls-unique")) + class ContextTests(unittest.TestCase): @skip_if_broken_ubuntu_ssl @@ -427,6 +467,60 @@ class ContextTests(unittest.TestCase): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"): ctx.load_cert_chain(SVN_PYTHON_ORG_ROOT_CERT, ONLYKEY) + # Password protected key and cert + ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD) + ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode()) + ctx.load_cert_chain(CERTFILE_PROTECTED, + password=bytearray(KEY_PASSWORD.encode())) + ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD) + ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode()) + ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, + bytearray(KEY_PASSWORD.encode())) + with self.assertRaisesRegex(TypeError, "should be a string"): + ctx.load_cert_chain(CERTFILE_PROTECTED, password=True) + with self.assertRaises(ssl.SSLError): + ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass") + with self.assertRaisesRegex(ValueError, "cannot be longer"): + # openssl has a fixed limit on the password buffer. + # PEM_BUFSIZE is generally set to 1kb. + # Return a string larger than this. + ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400) + # Password callback + def getpass_unicode(): + return KEY_PASSWORD + def getpass_bytes(): + return KEY_PASSWORD.encode() + def getpass_bytearray(): + return bytearray(KEY_PASSWORD.encode()) + def getpass_badpass(): + return "badpass" + def getpass_huge(): + return b'a' * (1024 * 1024) + def getpass_bad_type(): + return 9 + def getpass_exception(): + raise Exception('getpass error') + class GetPassCallable: + def __call__(self): + return KEY_PASSWORD + def getpass(self): + return KEY_PASSWORD + ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode) + ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes) + ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray) + ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable()) + ctx.load_cert_chain(CERTFILE_PROTECTED, + password=GetPassCallable().getpass) + with self.assertRaises(ssl.SSLError): + ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass) + with self.assertRaisesRegex(ValueError, "cannot be longer"): + ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge) + with self.assertRaisesRegex(TypeError, "must return a string"): + ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type) + with self.assertRaisesRegex(Exception, "getpass error"): + ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception) + # Make sure the password function isn't called if it isn't needed + ctx.load_cert_chain(CERTFILE, password=getpass_exception) def test_load_verify_locations(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) @@ -447,6 +541,19 @@ class ContextTests(unittest.TestCase): # Issue #10989: crash if the second argument type is invalid self.assertRaises(TypeError, ctx.load_verify_locations, None, True) + def test_load_dh_params(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_dh_params(DHFILE) + if os.name != 'nt': + ctx.load_dh_params(BYTES_DHFILE) + self.assertRaises(TypeError, ctx.load_dh_params) + self.assertRaises(TypeError, ctx.load_dh_params, None) + with self.assertRaises(FileNotFoundError) as cm: + ctx.load_dh_params(WRONGCERT) + self.assertEqual(cm.exception.errno, errno.ENOENT) + with self.assertRaisesRegex(ssl.SSLError, "PEM routines"): + ctx.load_dh_params(CERTFILE) + @skip_if_broken_ubuntu_ssl def test_session_stats(self): for proto in PROTOCOLS: @@ -471,6 +578,16 @@ class ContextTests(unittest.TestCase): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ctx.set_default_verify_paths() + @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build") + def test_set_ecdh_curve(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.set_ecdh_curve("prime256v1") + ctx.set_ecdh_curve(b"prime256v1") + self.assertRaises(TypeError, ctx.set_ecdh_curve) + self.assertRaises(TypeError, ctx.set_ecdh_curve, None) + self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo") + self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo") + class NetworkedTests(unittest.TestCase): @@ -533,13 +650,10 @@ class NetworkedTests(unittest.TestCase): try: s.do_handshake() break - except ssl.SSLError as err: - if err.args[0] == ssl.SSL_ERROR_WANT_READ: - select.select([s], [], [], 5.0) - elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: - select.select([], [s], [], 5.0) - else: - raise + except ssl.SSLWantReadError: + select.select([s], [], [], 5.0) + except ssl.SSLWantWriteError: + select.select([], [s], [], 5.0) # SSL established self.assertTrue(s.getpeercert()) finally: @@ -659,37 +773,39 @@ class NetworkedTests(unittest.TestCase): count += 1 s.do_handshake() break - except ssl.SSLError as err: - if err.args[0] == ssl.SSL_ERROR_WANT_READ: - select.select([s], [], []) - elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: - select.select([], [s], []) - else: - raise + except ssl.SSLWantReadError: + select.select([s], [], []) + except ssl.SSLWantWriteError: + select.select([], [s], []) s.close() if support.verbose: sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count) def test_get_server_certificate(self): - with support.transient_internet("svn.python.org"): - pem = ssl.get_server_certificate(("svn.python.org", 443)) - if not pem: - self.fail("No server certificate on svn.python.org:443!") + def _test_get_server_certificate(host, port, cert=None): + with support.transient_internet(host): + pem = ssl.get_server_certificate((host, port)) + if not pem: + self.fail("No server certificate on %s:%s!" % (host, port)) - try: - pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE) - except ssl.SSLError as x: - #should fail + try: + pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE) + except ssl.SSLError as x: + #should fail + if support.verbose: + sys.stdout.write("%s\n" % x) + else: + self.fail("Got server certificate %s for %s:%s!" % (pem, host, port)) + + pem = ssl.get_server_certificate((host, port), ca_certs=cert) + if not pem: + self.fail("No server certificate on %s:%s!" % (host, port)) if support.verbose: - sys.stdout.write("%s\n" % x) - else: - self.fail("Got server certificate %s for svn.python.org!" % pem) + sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port ,pem)) - pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT) - if not pem: - self.fail("No server certificate on svn.python.org:443!") - if support.verbose: - sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem) + _test_get_server_certificate('svn.python.org', 443, SVN_PYTHON_ORG_ROOT_CERT) + if support.IPV6_ENABLED: + _test_get_server_certificate('ipv6.google.com', 443) def test_ciphers(self): remote = ("svn.python.org", 443) @@ -838,6 +954,11 @@ else: self.sslconn = None if support.verbose and self.server.connectionchatty: sys.stdout.write(" server: connection is now unencrypted...\n") + elif stripped == b'CB tls-unique': + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n") + data = self.sslconn.get_channel_binding("tls-unique") + self.write(repr(data).encode("us-ascii") + b"\n") else: if (support.verbose and self.server.connectionchatty): @@ -945,12 +1066,11 @@ else: def _do_ssl_handshake(self): try: self.socket.do_handshake() - except ssl.SSLError as err: - if err.args[0] in (ssl.SSL_ERROR_WANT_READ, - ssl.SSL_ERROR_WANT_WRITE): - return - elif err.args[0] == ssl.SSL_ERROR_EOF: - return self.handle_close() + except (ssl.SSLWantReadError, ssl.SSLWantWriteError): + return + except ssl.SSLEOFError: + return self.handle_close() + except ssl.SSLError: raise except socket.error as err: if err.args[0] == errno.ECONNABORTED: @@ -1098,7 +1218,12 @@ else: if connectionchatty: if support.verbose: sys.stdout.write(" client: closing connection.\n") + stats = { + 'compression': s.compression(), + 'cipher': s.cipher(), + } s.close() + return stats def try_protocol_combo(server_protocol, client_protocol, expect_success, certsreqs=None, server_options=0, client_options=0): @@ -1250,7 +1375,8 @@ else: t.join() @skip_if_broken_ubuntu_ssl - @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'), "need SSLv2") + @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'), + "OpenSSL is compiled without SSLv2 support") def test_protocol_sslv2(self): """Connecting to an SSLv2 server with various client options""" if support.verbose: @@ -1556,6 +1682,15 @@ else: ) # consume data s.read() + + # Make sure sendmsg et al are disallowed to avoid + # inadvertent disclosure of data and/or corruption + # of the encrypted data stream + self.assertRaises(NotImplementedError, s.sendmsg, [b"data"]) + self.assertRaises(NotImplementedError, s.recvmsg, 100) + self.assertRaises(NotImplementedError, + s.recvmsg_into, bytearray(100)) + s.write(b"over\n") s.close() @@ -1624,6 +1759,98 @@ else: s.connect((HOST, server.port)) self.assertIn("no shared cipher", str(server.conn_errors[0])) + @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, + "'tls-unique' channel binding not available") + def test_tls_unique_channel_binding(self): + """Test tls-unique channel binding.""" + if support.verbose: + sys.stdout.write("\n") + + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + with server: + s = ssl.wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + # get the data + cb_data = s.get_channel_binding("tls-unique") + if support.verbose: + sys.stdout.write(" got channel binding data: {0!r}\n" + .format(cb_data)) + + # check if it is sane + self.assertIsNotNone(cb_data) + self.assertEqual(len(cb_data), 12) # True for TLSv1 + + # and compare with the peers version + s.write(b"CB tls-unique\n") + peer_data_repr = s.read().strip() + self.assertEqual(peer_data_repr, + repr(cb_data).encode("us-ascii")) + s.close() + + # now, again + s = ssl.wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + new_cb_data = s.get_channel_binding("tls-unique") + if support.verbose: + sys.stdout.write(" got another channel binding data: {0!r}\n" + .format(new_cb_data)) + # is it really unique + self.assertNotEqual(cb_data, new_cb_data) + self.assertIsNotNone(cb_data) + self.assertEqual(len(cb_data), 12) # True for TLSv1 + s.write(b"CB tls-unique\n") + peer_data_repr = s.read().strip() + self.assertEqual(peer_data_repr, + repr(new_cb_data).encode("us-ascii")) + s.close() + + def test_compression(self): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + if support.verbose: + sys.stdout.write(" got compression: {!r}\n".format(stats['compression'])) + self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' }) + + @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'), + "ssl.OP_NO_COMPRESSION needed for this test") + def test_compression_disabled(self): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + context.options |= ssl.OP_NO_COMPRESSION + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + self.assertIs(stats['compression'], None) + + def test_dh_params(self): + # Check we can get a connection with ephemeral Diffie-Hellman + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + context.load_dh_params(DHFILE) + context.set_ciphers("kEDH") + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + cipher = stats["cipher"][0] + parts = cipher.split("-") + if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: + self.fail("Non-DH cipher: " + cipher[0]) + def test_main(verbose=False): if support.verbose: |