diff options
Diffstat (limited to 'Lib/test/test_ssl.py')
| -rw-r--r-- | Lib/test/test_ssl.py | 206 | 
1 files changed, 189 insertions, 17 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index a79fce6..25f3e4f 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -42,6 +42,9 @@ ONLYCERT = data_file("ssl_cert.pem")  ONLYKEY = data_file("ssl_key.pem")  BYTES_ONLYCERT = os.fsencode(ONLYCERT)  BYTES_ONLYKEY = os.fsencode(ONLYKEY) +CERTFILE_PROTECTED = data_file("keycert.passwd.pem") +ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem") +KEY_PASSWORD = "somepass"  CAPATH = data_file("capath")  BYTES_CAPATH = os.fsencode(CAPATH) @@ -103,6 +106,16 @@ class BasicSocketTests(unittest.TestCase):              sys.stdout.write("\n RAND_status is %d (%s)\n"                               % (v, (v and "sufficient randomness") or                                  "insufficient randomness")) + +        data, is_cryptographic = ssl.RAND_pseudo_bytes(16) +        self.assertEqual(len(data), 16) +        self.assertEqual(is_cryptographic, v == 1) +        if v: +            data = ssl.RAND_bytes(16) +            self.assertEqual(len(data), 16) +        else: +            self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16) +          try:              ssl.RAND_egd(1)          except TypeError: @@ -337,6 +350,25 @@ class BasicSocketTests(unittest.TestCase):              self.assertRaises(ValueError, ctx.wrap_socket, sock, True,                                server_hostname="some.hostname") +    def test_unknown_channel_binding(self): +        # should raise ValueError for unknown type +        s = socket.socket(socket.AF_INET) +        ss = ssl.wrap_socket(s) +        with self.assertRaises(ValueError): +            ss.get_channel_binding("unknown-type") + +    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, +                         "'tls-unique' channel binding not available") +    def test_tls_unique_channel_binding(self): +        # unconnected should return None for known type +        s = socket.socket(socket.AF_INET) +        ss = ssl.wrap_socket(s) +        self.assertIsNone(ss.get_channel_binding("tls-unique")) +        # the same for server-side +        s = socket.socket(socket.AF_INET) +        ss = ssl.wrap_socket(s, server_side=True, certfile=CERTFILE) +        self.assertIsNone(ss.get_channel_binding("tls-unique")) +  class ContextTests(unittest.TestCase):      @skip_if_broken_ubuntu_ssl @@ -427,6 +459,60 @@ class ContextTests(unittest.TestCase):          ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)          with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):              ctx.load_cert_chain(SVN_PYTHON_ORG_ROOT_CERT, ONLYKEY) +        # Password protected key and cert +        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD) +        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode()) +        ctx.load_cert_chain(CERTFILE_PROTECTED, +                            password=bytearray(KEY_PASSWORD.encode())) +        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD) +        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode()) +        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, +                            bytearray(KEY_PASSWORD.encode())) +        with self.assertRaisesRegex(TypeError, "should be a string"): +            ctx.load_cert_chain(CERTFILE_PROTECTED, password=True) +        with self.assertRaises(ssl.SSLError): +            ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass") +        with self.assertRaisesRegex(ValueError, "cannot be longer"): +            # openssl has a fixed limit on the password buffer. +            # PEM_BUFSIZE is generally set to 1kb. +            # Return a string larger than this. +            ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400) +        # Password callback +        def getpass_unicode(): +            return KEY_PASSWORD +        def getpass_bytes(): +            return KEY_PASSWORD.encode() +        def getpass_bytearray(): +            return bytearray(KEY_PASSWORD.encode()) +        def getpass_badpass(): +            return "badpass" +        def getpass_huge(): +            return b'a' * (1024 * 1024) +        def getpass_bad_type(): +            return 9 +        def getpass_exception(): +            raise Exception('getpass error') +        class GetPassCallable: +            def __call__(self): +                return KEY_PASSWORD +            def getpass(self): +                return KEY_PASSWORD +        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode) +        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes) +        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray) +        ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable()) +        ctx.load_cert_chain(CERTFILE_PROTECTED, +                            password=GetPassCallable().getpass) +        with self.assertRaises(ssl.SSLError): +            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass) +        with self.assertRaisesRegex(ValueError, "cannot be longer"): +            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge) +        with self.assertRaisesRegex(TypeError, "must return a string"): +            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type) +        with self.assertRaisesRegex(Exception, "getpass error"): +            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception) +        # Make sure the password function isn't called if it isn't needed +        ctx.load_cert_chain(CERTFILE, password=getpass_exception)      def test_load_verify_locations(self):          ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) @@ -671,25 +757,30 @@ class NetworkedTests(unittest.TestCase):                  sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)      def test_get_server_certificate(self): -        with support.transient_internet("svn.python.org"): -            pem = ssl.get_server_certificate(("svn.python.org", 443)) -            if not pem: -                self.fail("No server certificate on svn.python.org:443!") +        def _test_get_server_certificate(host, port, cert=None): +            with support.transient_internet(host): +                pem = ssl.get_server_certificate((host, port)) +                if not pem: +                    self.fail("No server certificate on %s:%s!" % (host, port)) -            try: -                pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE) -            except ssl.SSLError as x: -                #should fail +                try: +                    pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE) +                except ssl.SSLError as x: +                    #should fail +                    if support.verbose: +                        sys.stdout.write("%s\n" % x) +                else: +                    self.fail("Got server certificate %s for %s:%s!" % (pem, host, port)) + +                pem = ssl.get_server_certificate((host, port), ca_certs=cert) +                if not pem: +                    self.fail("No server certificate on %s:%s!" % (host, port))                  if support.verbose: -                    sys.stdout.write("%s\n" % x) -            else: -                self.fail("Got server certificate %s for svn.python.org!" % pem) +                    sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port ,pem)) -            pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT) -            if not pem: -                self.fail("No server certificate on svn.python.org:443!") -            if support.verbose: -                sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem) +        _test_get_server_certificate('svn.python.org', 443, SVN_PYTHON_ORG_ROOT_CERT) +        if support.IPV6_ENABLED: +            _test_get_server_certificate('ipv6.google.com', 443)      def test_ciphers(self):          remote = ("svn.python.org", 443) @@ -837,6 +928,11 @@ else:                              self.sslconn = None                              if support.verbose and self.server.connectionchatty:                                  sys.stdout.write(" server: connection is now unencrypted...\n") +                        elif stripped == b'CB tls-unique': +                            if support.verbose and self.server.connectionchatty: +                                sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n") +                            data = self.sslconn.get_channel_binding("tls-unique") +                            self.write(repr(data).encode("us-ascii") + b"\n")                          else:                              if (support.verbose and                                  self.server.connectionchatty): @@ -1248,7 +1344,8 @@ else:                  t.join()          @skip_if_broken_ubuntu_ssl -        @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'), "need SSLv2") +        @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'), +                             "OpenSSL is compiled without SSLv2 support")          def test_protocol_sslv2(self):              """Connecting to an SSLv2 server with various client options"""              if support.verbose: @@ -1580,6 +1677,14 @@ else:                          # consume data                          s.read() +                # Make sure sendmsg et al are disallowed to avoid +                # inadvertent disclosure of data and/or corruption +                # of the encrypted data stream +                self.assertRaises(NotImplementedError, s.sendmsg, [b"data"]) +                self.assertRaises(NotImplementedError, s.recvmsg, 100) +                self.assertRaises(NotImplementedError, +                                  s.recvmsg_into, bytearray(100)) +                  s.write(b"over\n")                  s.close()              finally: @@ -1635,6 +1740,73 @@ else:                  t.join()                  server.close() +        @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, +                             "'tls-unique' channel binding not available") +        def test_tls_unique_channel_binding(self): +            """Test tls-unique channel binding.""" +            if support.verbose: +                sys.stdout.write("\n") + +            server = ThreadedEchoServer(CERTFILE, +                                        certreqs=ssl.CERT_NONE, +                                        ssl_version=ssl.PROTOCOL_TLSv1, +                                        cacerts=CERTFILE, +                                        chatty=True, +                                        connectionchatty=False) +            flag = threading.Event() +            server.start(flag) +            # wait for it to start +            flag.wait() +            # try to connect +            s = ssl.wrap_socket(socket.socket(), +                                server_side=False, +                                certfile=CERTFILE, +                                ca_certs=CERTFILE, +                                cert_reqs=ssl.CERT_NONE, +                                ssl_version=ssl.PROTOCOL_TLSv1) +            s.connect((HOST, server.port)) +            try: +                # get the data +                cb_data = s.get_channel_binding("tls-unique") +                if support.verbose: +                    sys.stdout.write(" got channel binding data: {0!r}\n" +                                     .format(cb_data)) + +                # check if it is sane +                self.assertIsNotNone(cb_data) +                self.assertEqual(len(cb_data), 12) # True for TLSv1 + +                # and compare with the peers version +                s.write(b"CB tls-unique\n") +                peer_data_repr = s.read().strip() +                self.assertEqual(peer_data_repr, +                                 repr(cb_data).encode("us-ascii")) +                s.close() + +                # now, again +                s = ssl.wrap_socket(socket.socket(), +                                    server_side=False, +                                    certfile=CERTFILE, +                                    ca_certs=CERTFILE, +                                    cert_reqs=ssl.CERT_NONE, +                                    ssl_version=ssl.PROTOCOL_TLSv1) +                s.connect((HOST, server.port)) +                new_cb_data = s.get_channel_binding("tls-unique") +                if support.verbose: +                    sys.stdout.write(" got another channel binding data: {0!r}\n" +                                     .format(new_cb_data)) +                # is it really unique +                self.assertNotEqual(cb_data, new_cb_data) +                self.assertIsNotNone(cb_data) +                self.assertEqual(len(cb_data), 12) # True for TLSv1 +                s.write(b"CB tls-unique\n") +                peer_data_repr = s.read().strip() +                self.assertEqual(peer_data_repr, +                                 repr(new_cb_data).encode("us-ascii")) +                s.close() +            finally: +                server.stop() +                server.join()  def test_main(verbose=False):      if support.verbose:  | 
