diff options
Diffstat (limited to 'Lib/test/test_ssl.py')
| -rw-r--r-- | Lib/test/test_ssl.py | 477 | 
1 files changed, 440 insertions, 37 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 779b622..ea619fd 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -86,6 +86,12 @@ def have_verify_flags():      # 0.9.8 or higher      return ssl.OPENSSL_VERSION_INFO >= (0, 9, 8, 0, 15) +def utc_offset(): #NOTE: ignore issues like #1647654 +    # local time = utc time + utc offset +    if time.daylight and time.localtime().tm_isdst > 0: +        return -time.altzone  # seconds +    return -time.timezone +  def asn1time(cert_time):      # Some versions of OpenSSL ignore seconds, see #18207      # 0.9.8.i @@ -134,6 +140,14 @@ class BasicSocketTests(unittest.TestCase):          self.assertIn(ssl.HAS_SNI, {True, False})          self.assertIn(ssl.HAS_ECDH, {True, False}) +    def test_str_for_enums(self): +        # Make sure that the PROTOCOL_* constants have enum-like string +        # reprs. +        proto = ssl.PROTOCOL_SSLv23 +        self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_SSLv23') +        ctx = ssl.SSLContext(proto) +        self.assertIs(ctx.protocol, proto) +      def test_random(self):          v = ssl.RAND_status()          if support.verbose: @@ -298,10 +312,10 @@ class BasicSocketTests(unittest.TestCase):          # Version string as returned by {Open,Libre}SSL, the format might change          if "LibreSSL" in s:              self.assertTrue(s.startswith("LibreSSL {:d}.{:d}".format(major, minor)), -                            (s, t)) +                            (s, t, hex(n)))          else:              self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)), -                            (s, t)) +                            (s, t, hex(n)))      @support.cpython_only      def test_refcycle(self): @@ -369,6 +383,8 @@ class BasicSocketTests(unittest.TestCase):              self.assertRaises(ssl.CertificateError,                                ssl.match_hostname, cert, hostname) +        # -- Hostname matching -- +          cert = {'subject': ((('commonName', 'example.com'),),)}          ok(cert, 'example.com')          ok(cert, 'ExAmple.cOm') @@ -454,6 +470,28 @@ class BasicSocketTests(unittest.TestCase):          # Only commonName is considered          fail(cert, 'California') +        # -- IPv4 matching -- +        cert = {'subject': ((('commonName', 'example.com'),),), +                'subjectAltName': (('DNS', 'example.com'), +                                   ('IP Address', '10.11.12.13'), +                                   ('IP Address', '14.15.16.17'))} +        ok(cert, '10.11.12.13') +        ok(cert, '14.15.16.17') +        fail(cert, '14.15.16.18') +        fail(cert, 'example.net') + +        # -- IPv6 matching -- +        cert = {'subject': ((('commonName', 'example.com'),),), +                'subjectAltName': (('DNS', 'example.com'), +                                   ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'), +                                   ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))} +        ok(cert, '2001::cafe') +        ok(cert, '2003::baba') +        fail(cert, '2003::bebe') +        fail(cert, 'example.net') + +        # -- Miscellaneous -- +          # Neither commonName nor subjectAltName          cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',                  'subject': ((('countryName', 'US'),), @@ -505,9 +543,14 @@ class BasicSocketTests(unittest.TestCase):      def test_unknown_channel_binding(self):          # should raise ValueError for unknown type          s = socket.socket(socket.AF_INET) -        with ssl.wrap_socket(s) as ss: +        s.bind(('127.0.0.1', 0)) +        s.listen() +        c = socket.socket(socket.AF_INET) +        c.connect(s.getsockname()) +        with ssl.wrap_socket(c, do_handshake_on_connect=False) as ss:              with self.assertRaises(ValueError):                  ss.get_channel_binding("unknown-type") +        s.close()      @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,                           "'tls-unique' channel binding not available") @@ -648,6 +691,71 @@ class BasicSocketTests(unittest.TestCase):              ctx.wrap_socket(s)          self.assertEqual(str(cx.exception), "only stream sockets are supported") +    def cert_time_ok(self, timestring, timestamp): +        self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp) + +    def cert_time_fail(self, timestring): +        with self.assertRaises(ValueError): +            ssl.cert_time_to_seconds(timestring) + +    @unittest.skipUnless(utc_offset(), +                         'local time needs to be different from UTC') +    def test_cert_time_to_seconds_timezone(self): +        # Issue #19940: ssl.cert_time_to_seconds() returns wrong +        #               results if local timezone is not UTC +        self.cert_time_ok("May  9 00:00:00 2007 GMT", 1178668800.0) +        self.cert_time_ok("Jan  5 09:34:43 2018 GMT", 1515144883.0) + +    def test_cert_time_to_seconds(self): +        timestring = "Jan  5 09:34:43 2018 GMT" +        ts = 1515144883.0 +        self.cert_time_ok(timestring, ts) +        # accept keyword parameter, assert its name +        self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts) +        # accept both %e and %d (space or zero generated by strftime) +        self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts) +        # case-insensitive +        self.cert_time_ok("JaN  5 09:34:43 2018 GmT", ts) +        self.cert_time_fail("Jan  5 09:34 2018 GMT")     # no seconds +        self.cert_time_fail("Jan  5 09:34:43 2018")      # no GMT +        self.cert_time_fail("Jan  5 09:34:43 2018 UTC")  # not GMT timezone +        self.cert_time_fail("Jan 35 09:34:43 2018 GMT")  # invalid day +        self.cert_time_fail("Jon  5 09:34:43 2018 GMT")  # invalid month +        self.cert_time_fail("Jan  5 24:00:00 2018 GMT")  # invalid hour +        self.cert_time_fail("Jan  5 09:60:43 2018 GMT")  # invalid minute + +        newyear_ts = 1230768000.0 +        # leap seconds +        self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts) +        # same timestamp +        self.cert_time_ok("Jan  1 00:00:00 2009 GMT", newyear_ts) + +        self.cert_time_ok("Jan  5 09:34:59 2018 GMT", 1515144899) +        #  allow 60th second (even if it is not a leap second) +        self.cert_time_ok("Jan  5 09:34:60 2018 GMT", 1515144900) +        #  allow 2nd leap second for compatibility with time.strptime() +        self.cert_time_ok("Jan  5 09:34:61 2018 GMT", 1515144901) +        self.cert_time_fail("Jan  5 09:34:62 2018 GMT")  # invalid seconds + +        # no special treatement for the special value: +        #   99991231235959Z (rfc 5280) +        self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0) + +    @support.run_with_locale('LC_ALL', '') +    def test_cert_time_to_seconds_locale(self): +        # `cert_time_to_seconds()` should be locale independent + +        def local_february_name(): +            return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0)) + +        if local_february_name().lower() == 'feb': +            self.skipTest("locale-specific month name needs to be " +                          "different from C locale") + +        # locale-independent +        self.cert_time_ok("Feb  9 00:00:00 2007 GMT", 1170979200.0) +        self.cert_time_fail(local_february_name() + "  9 00:00:00 2007 GMT") +  class ContextTests(unittest.TestCase): @@ -1157,7 +1265,7 @@ class SSLErrorTests(unittest.TestCase):          ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)          with socket.socket() as s:              s.bind(("127.0.0.1", 0)) -            s.listen(5) +            s.listen()              c = socket.socket()              c.connect(s.getsockname())              c.setblocking(False) @@ -1170,6 +1278,69 @@ class SSLErrorTests(unittest.TestCase):                  self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ) +class MemoryBIOTests(unittest.TestCase): + +    def test_read_write(self): +        bio = ssl.MemoryBIO() +        bio.write(b'foo') +        self.assertEqual(bio.read(), b'foo') +        self.assertEqual(bio.read(), b'') +        bio.write(b'foo') +        bio.write(b'bar') +        self.assertEqual(bio.read(), b'foobar') +        self.assertEqual(bio.read(), b'') +        bio.write(b'baz') +        self.assertEqual(bio.read(2), b'ba') +        self.assertEqual(bio.read(1), b'z') +        self.assertEqual(bio.read(1), b'') + +    def test_eof(self): +        bio = ssl.MemoryBIO() +        self.assertFalse(bio.eof) +        self.assertEqual(bio.read(), b'') +        self.assertFalse(bio.eof) +        bio.write(b'foo') +        self.assertFalse(bio.eof) +        bio.write_eof() +        self.assertFalse(bio.eof) +        self.assertEqual(bio.read(2), b'fo') +        self.assertFalse(bio.eof) +        self.assertEqual(bio.read(1), b'o') +        self.assertTrue(bio.eof) +        self.assertEqual(bio.read(), b'') +        self.assertTrue(bio.eof) + +    def test_pending(self): +        bio = ssl.MemoryBIO() +        self.assertEqual(bio.pending, 0) +        bio.write(b'foo') +        self.assertEqual(bio.pending, 3) +        for i in range(3): +            bio.read(1) +            self.assertEqual(bio.pending, 3-i-1) +        for i in range(3): +            bio.write(b'x') +            self.assertEqual(bio.pending, i+1) +        bio.read() +        self.assertEqual(bio.pending, 0) + +    def test_buffer_types(self): +        bio = ssl.MemoryBIO() +        bio.write(b'foo') +        self.assertEqual(bio.read(), b'foo') +        bio.write(bytearray(b'bar')) +        self.assertEqual(bio.read(), b'bar') +        bio.write(memoryview(b'baz')) +        self.assertEqual(bio.read(), b'baz') + +    def test_error_types(self): +        bio = ssl.MemoryBIO() +        self.assertRaises(TypeError, bio.write, 'foo') +        self.assertRaises(TypeError, bio.write, None) +        self.assertRaises(TypeError, bio.write, True) +        self.assertRaises(TypeError, bio.write, 1) + +  class NetworkedTests(unittest.TestCase):      def test_connect(self): @@ -1397,14 +1568,12 @@ class NetworkedTests(unittest.TestCase):      def test_get_server_certificate(self):          def _test_get_server_certificate(host, port, cert=None):              with support.transient_internet(host): -                pem = ssl.get_server_certificate((host, port), -                                                 ssl.PROTOCOL_SSLv23) +                pem = ssl.get_server_certificate((host, port))                  if not pem:                      self.fail("No server certificate on %s:%s!" % (host, port))                  try:                      pem = ssl.get_server_certificate((host, port), -                                                     ssl.PROTOCOL_SSLv23,                                                       ca_certs=CERTFILE)                  except ssl.SSLError as x:                      #should fail @@ -1414,7 +1583,6 @@ class NetworkedTests(unittest.TestCase):                      self.fail("Got server certificate %s for %s:%s!" % (pem, host, port))                  pem = ssl.get_server_certificate((host, port), -                                                 ssl.PROTOCOL_SSLv23,                                                   ca_certs=cert)                  if not pem:                      self.fail("No server certificate on %s:%s!" % (host, port)) @@ -1500,6 +1668,93 @@ class NetworkedTests(unittest.TestCase):                  self.assertIs(ss.context, ctx2)                  self.assertIs(ss._sslobj.context, ctx2) + +class NetworkedBIOTests(unittest.TestCase): + +    def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs): +        # A simple IO loop. Call func(*args) depending on the error we get +        # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs. +        timeout = kwargs.get('timeout', 10) +        count = 0 +        while True: +            errno = None +            count += 1 +            try: +                ret = func(*args) +            except ssl.SSLError as e: +                # Note that we get a spurious -1/SSL_ERROR_SYSCALL for +                # non-blocking IO. The SSL_shutdown manpage hints at this. +                # It *should* be safe to just ignore SYS_ERROR_SYSCALL because +                # with a Memory BIO there's no syscalls (for IO at least). +                if e.errno not in (ssl.SSL_ERROR_WANT_READ, +                                   ssl.SSL_ERROR_WANT_WRITE, +                                   ssl.SSL_ERROR_SYSCALL): +                    raise +                errno = e.errno +            # Get any data from the outgoing BIO irrespective of any error, and +            # send it to the socket. +            buf = outgoing.read() +            sock.sendall(buf) +            # If there's no error, we're done. For WANT_READ, we need to get +            # data from the socket and put it in the incoming BIO. +            if errno is None: +                break +            elif errno == ssl.SSL_ERROR_WANT_READ: +                buf = sock.recv(32768) +                if buf: +                    incoming.write(buf) +                else: +                    incoming.write_eof() +        if support.verbose: +            sys.stdout.write("Needed %d calls to complete %s().\n" +                             % (count, func.__name__)) +        return ret + +    def test_handshake(self): +        with support.transient_internet("svn.python.org"): +            sock = socket.socket(socket.AF_INET) +            sock.connect(("svn.python.org", 443)) +            incoming = ssl.MemoryBIO() +            outgoing = ssl.MemoryBIO() +            ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) +            ctx.verify_mode = ssl.CERT_REQUIRED +            ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT) +            ctx.check_hostname = True +            sslobj = ctx.wrap_bio(incoming, outgoing, False, 'svn.python.org') +            self.assertIs(sslobj._sslobj.owner, sslobj) +            self.assertIsNone(sslobj.cipher()) +            self.assertIsNone(sslobj.shared_ciphers()) +            self.assertRaises(ValueError, sslobj.getpeercert) +            if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: +                self.assertIsNone(sslobj.get_channel_binding('tls-unique')) +            self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) +            self.assertTrue(sslobj.cipher()) +            self.assertIsNone(sslobj.shared_ciphers()) +            self.assertTrue(sslobj.getpeercert()) +            if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: +                self.assertTrue(sslobj.get_channel_binding('tls-unique')) +            self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) +            self.assertRaises(ssl.SSLError, sslobj.write, b'foo') +            sock.close() + +    def test_read_write_data(self): +        with support.transient_internet("svn.python.org"): +            sock = socket.socket(socket.AF_INET) +            sock.connect(("svn.python.org", 443)) +            incoming = ssl.MemoryBIO() +            outgoing = ssl.MemoryBIO() +            ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) +            ctx.verify_mode = ssl.CERT_NONE +            sslobj = ctx.wrap_bio(incoming, outgoing, False) +            self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) +            req = b'GET / HTTP/1.0\r\n\r\n' +            self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req) +            buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024) +            self.assertEqual(buf[:5], b'HTTP/') +            self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) +            sock.close() + +  try:      import threading  except ImportError: @@ -1531,7 +1786,8 @@ else:                  try:                      self.sslconn = self.server.context.wrap_socket(                          self.sock, server_side=True) -                    self.server.selected_protocols.append(self.sslconn.selected_npn_protocol()) +                    self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol()) +                    self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())                  except (ssl.SSLError, ConnectionResetError) as e:                      # We treat ConnectionResetError as though it were an                      # SSLError - OpenSSL on Ubuntu abruptly closes the @@ -1548,6 +1804,7 @@ else:                      self.close()                      return False                  else: +                    self.server.shared_ciphers.append(self.sslconn.shared_ciphers())                      if self.server.context.verify_mode == ssl.CERT_REQUIRED:                          cert = self.sslconn.getpeercert()                          if support.verbose and self.server.chatty: @@ -1638,7 +1895,8 @@ else:          def __init__(self, certificate=None, ssl_version=None,                       certreqs=None, cacerts=None,                       chatty=True, connectionchatty=False, starttls_server=False, -                     npn_protocols=None, ciphers=None, context=None): +                     npn_protocols=None, alpn_protocols=None, +                     ciphers=None, context=None):              if context:                  self.context = context              else: @@ -1653,6 +1911,8 @@ else:                      self.context.load_cert_chain(certificate)                  if npn_protocols:                      self.context.set_npn_protocols(npn_protocols) +                if alpn_protocols: +                    self.context.set_alpn_protocols(alpn_protocols)                  if ciphers:                      self.context.set_ciphers(ciphers)              self.chatty = chatty @@ -1662,7 +1922,9 @@ else:              self.port = support.bind_port(self.sock)              self.flag = None              self.active = False -            self.selected_protocols = [] +            self.selected_npn_protocols = [] +            self.selected_alpn_protocols = [] +            self.shared_ciphers = []              self.conn_errors = []              threading.Thread.__init__(self)              self.daemon = True @@ -1682,7 +1944,7 @@ else:          def run(self):              self.sock.settimeout(0.05) -            self.sock.listen(5) +            self.sock.listen()              self.active = True              if self.flag:                  # signal an event @@ -1888,14 +2150,25 @@ else:                      'compression': s.compression(),                      'cipher': s.cipher(),                      'peercert': s.getpeercert(), -                    'client_npn_protocol': s.selected_npn_protocol() +                    'client_alpn_protocol': s.selected_alpn_protocol(), +                    'client_npn_protocol': s.selected_npn_protocol(), +                    'version': s.version(),                  })                  s.close() -            stats['server_npn_protocols'] = server.selected_protocols +            stats['server_alpn_protocols'] = server.selected_alpn_protocols +            stats['server_npn_protocols'] = server.selected_npn_protocols +            stats['server_shared_ciphers'] = server.shared_ciphers          return stats      def try_protocol_combo(server_protocol, client_protocol, expect_success,                             certsreqs=None, server_options=0, client_options=0): +        """ +        Try to SSL-connect using *client_protocol* to *server_protocol*. +        If *expect_success* is true, assert that the connection succeeds, +        if it's false, assert that the connection fails. +        Also, if *expect_success* is a string, assert that it is the protocol +        version actually used by the connection. +        """          if certsreqs is None:              certsreqs = ssl.CERT_NONE          certtype = { @@ -1925,8 +2198,8 @@ else:              ctx.load_cert_chain(CERTFILE)              ctx.load_verify_locations(CERTFILE)          try: -            server_params_test(client_context, server_context, -                               chatty=False, connectionchatty=False) +            stats = server_params_test(client_context, server_context, +                                       chatty=False, connectionchatty=False)          # Protocol mismatch can result in either an SSLError, or a          # "Connection reset by peer" error.          except ssl.SSLError: @@ -1941,6 +2214,10 @@ else:                      "Client protocol %s succeeded with server protocol %s!"                      % (ssl.get_protocol_name(client_protocol),                         ssl.get_protocol_name(server_protocol))) +            elif (expect_success is not True +                  and expect_success != stats['version']): +                raise AssertionError("version mismatch: expected %r, got %r" +                                     % (expect_success, stats['version']))      class ThreadedTests(unittest.TestCase): @@ -2108,7 +2385,7 @@ else:              # and sets Event `listener_gone` to let the main thread know              # the socket is gone.              def listener(): -                s.listen(5) +                s.listen()                  listener_ready.set()                  newsock, addr = s.accept()                  newsock.close() @@ -2173,19 +2450,19 @@ else:                              " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"                              % str(x))              if hasattr(ssl, 'PROTOCOL_SSLv3'): -                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True) +                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3')              try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) -            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True) +            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1')              if hasattr(ssl, 'PROTOCOL_SSLv3'): -                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) +                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)              try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) -            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) +            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)              if hasattr(ssl, 'PROTOCOL_SSLv3'): -                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) +                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)              try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) -            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) +            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)              # Server with specific SSL options              if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2205,9 +2482,9 @@ else:              """Connecting to an SSLv3 server with various client options"""              if support.verbose:                  sys.stdout.write("\n") -            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True) -            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) -            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) +            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3') +            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) +            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)              if hasattr(ssl, 'PROTOCOL_SSLv2'):                  try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)              try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, @@ -2215,7 +2492,7 @@ else:              try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)              if no_sslv2_implies_sslv3_hello():                  # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs -                try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, True, +                try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, 'SSLv3',                                     client_options=ssl.OP_NO_SSLv2)          @skip_if_broken_ubuntu_ssl @@ -2223,9 +2500,9 @@ else:              """Connecting to a TLSv1 server with various client options"""              if support.verbose:                  sys.stdout.write("\n") -            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True) -            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) -            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) +            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1') +            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) +            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)              if hasattr(ssl, 'PROTOCOL_SSLv2'):                  try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)              if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2241,7 +2518,7 @@ else:                 Testing against older TLS versions."""              if support.verbose:                  sys.stdout.write("\n") -            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, True) +            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')              if hasattr(ssl, 'PROTOCOL_SSLv2'):                  try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)              if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2249,7 +2526,7 @@ else:              try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False,                                 client_options=ssl.OP_NO_TLSv1_1) -            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, True) +            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')              try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False)              try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) @@ -2262,7 +2539,7 @@ else:                 Testing against older TLS versions."""              if support.verbose:                  sys.stdout.write("\n") -            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, True, +            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',                                 server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,                                 client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)              if hasattr(ssl, 'PROTOCOL_SSLv2'): @@ -2272,7 +2549,7 @@ else:              try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False,                                 client_options=ssl.OP_NO_TLSv1_2) -            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, True) +            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')              try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)              try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)              try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) @@ -2508,6 +2785,36 @@ else:                  s.write(b"over\n")                  s.close() +        def test_nonblocking_send(self): +            server = ThreadedEchoServer(CERTFILE, +                                        certreqs=ssl.CERT_NONE, +                                        ssl_version=ssl.PROTOCOL_TLSv1, +                                        cacerts=CERTFILE, +                                        chatty=True, +                                        connectionchatty=False) +            with server: +                s = ssl.wrap_socket(socket.socket(), +                                    server_side=False, +                                    certfile=CERTFILE, +                                    ca_certs=CERTFILE, +                                    cert_reqs=ssl.CERT_NONE, +                                    ssl_version=ssl.PROTOCOL_TLSv1) +                s.connect((HOST, server.port)) +                s.setblocking(False) + +                # If we keep sending data, at some point the buffers +                # will be full and the call will block +                buf = bytearray(8192) +                def fill_buffer(): +                    while True: +                        s.send(buf) +                self.assertRaises((ssl.SSLWantWriteError, +                                   ssl.SSLWantReadError), fill_buffer) + +                # Now read all the output and discard it +                s.setblocking(True) +                s.close() +          def test_handshake_timeout(self):              # Issue #5103: SSL handshake must respect the socket timeout              server = socket.socket(socket.AF_INET) @@ -2517,7 +2824,7 @@ else:              finish = False              def serve(): -                server.listen(5) +                server.listen()                  started.set()                  conns = []                  while not finish: @@ -2574,7 +2881,7 @@ else:              peer = None              def serve():                  nonlocal remote, peer -                server.listen(5) +                server.listen()                  # Block on the accept and wait on the connection to close.                  evt.set()                  remote, peer = server.accept() @@ -2624,6 +2931,21 @@ else:                          s.connect((HOST, server.port))              self.assertIn("no shared cipher", str(server.conn_errors[0])) +        def test_version_basic(self): +            """ +            Basic tests for SSLSocket.version(). +            More tests are done in the test_protocol_*() methods. +            """ +            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +            with ThreadedEchoServer(CERTFILE, +                                    ssl_version=ssl.PROTOCOL_TLSv1, +                                    chatty=False) as server: +                with context.wrap_socket(socket.socket()) as s: +                    self.assertIs(s.version(), None) +                    s.connect((HOST, server.port)) +                    self.assertEqual(s.version(), "TLSv1") +                self.assertIs(s.version(), None) +          @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")          def test_default_ecdh_curve(self):              # Issue #21015: elliptic curve-based Diffie Hellman key exchange @@ -2733,6 +3055,55 @@ else:              if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:                  self.fail("Non-DH cipher: " + cipher[0]) +        def test_selected_alpn_protocol(self): +            # selected_alpn_protocol() is None unless ALPN is used. +            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +            context.load_cert_chain(CERTFILE) +            stats = server_params_test(context, context, +                                       chatty=True, connectionchatty=True) +            self.assertIs(stats['client_alpn_protocol'], None) + +        @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required") +        def test_selected_alpn_protocol_if_server_uses_alpn(self): +            # selected_alpn_protocol() is None unless ALPN is used by the client. +            client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +            client_context.load_verify_locations(CERTFILE) +            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +            server_context.load_cert_chain(CERTFILE) +            server_context.set_alpn_protocols(['foo', 'bar']) +            stats = server_params_test(client_context, server_context, +                                       chatty=True, connectionchatty=True) +            self.assertIs(stats['client_alpn_protocol'], None) + +        @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test") +        def test_alpn_protocols(self): +            server_protocols = ['foo', 'bar', 'milkshake'] +            protocol_tests = [ +                (['foo', 'bar'], 'foo'), +                (['bar', 'foo'], 'foo'), +                (['milkshake'], 'milkshake'), +                (['http/3.0', 'http/4.0'], None) +            ] +            for client_protocols, expected in protocol_tests: +                server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +                server_context.load_cert_chain(CERTFILE) +                server_context.set_alpn_protocols(server_protocols) +                client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +                client_context.load_cert_chain(CERTFILE) +                client_context.set_alpn_protocols(client_protocols) +                stats = server_params_test(client_context, server_context, +                                           chatty=True, connectionchatty=True) + +                msg = "failed trying %s (s) and %s (c).\n" \ +                      "was expecting %s, but got %%s from the %%s" \ +                          % (str(server_protocols), str(client_protocols), +                             str(expected)) +                client_result = stats['client_alpn_protocol'] +                self.assertEqual(client_result, expected, msg % (client_result, "client")) +                server_result = stats['server_alpn_protocols'][-1] \ +                    if len(stats['server_alpn_protocols']) else 'nothing' +                self.assertEqual(server_result, expected, msg % (server_result, "server")) +          def test_selected_npn_protocol(self):              # selected_npn_protocol() is None unless NPN is used              context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) @@ -2873,6 +3244,20 @@ else:              self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')              self.assertIn("TypeError", stderr.getvalue()) +        def test_shared_ciphers(self): +            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +            server_context.load_cert_chain(SIGNED_CERTFILE) +            client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +            client_context.verify_mode = ssl.CERT_REQUIRED +            client_context.load_verify_locations(SIGNING_CA) +            client_context.set_ciphers("RC4") +            server_context.set_ciphers("AES:RC4") +            stats = server_params_test(client_context, server_context) +            ciphers = stats['server_shared_ciphers'][0] +            self.assertGreater(len(ciphers), 0) +            for name, tls_version, bits in ciphers: +                self.assertIn("RC4", name.split("-")) +          def test_read_write_after_close_raises_valuerror(self):              context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)              context.verify_mode = ssl.CERT_REQUIRED @@ -2888,6 +3273,23 @@ else:                  self.assertRaises(ValueError, s.read, 1024)                  self.assertRaises(ValueError, s.write, b'hello') +        def test_sendfile(self): +            TEST_DATA = b"x" * 512 +            with open(support.TESTFN, 'wb') as f: +                f.write(TEST_DATA) +            self.addCleanup(support.unlink, support.TESTFN) +            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) +            context.verify_mode = ssl.CERT_REQUIRED +            context.load_verify_locations(CERTFILE) +            context.load_cert_chain(CERTFILE) +            server = ThreadedEchoServer(context=context, chatty=False) +            with server: +                with context.wrap_socket(socket.socket()) as s: +                    s.connect((HOST, server.port)) +                    with open(support.TESTFN, 'rb') as file: +                        s.sendfile(file) +                        self.assertEqual(s.recv(1024), TEST_DATA) +  def test_main(verbose=False):      if support.verbose: @@ -2921,10 +3323,11 @@ def test_main(verbose=False):          if not os.path.exists(filename):              raise support.TestFailed("Can't read certificate file %r" % filename) -    tests = [ContextTests, BasicSocketTests, SSLErrorTests] +    tests = [ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests]      if support.is_resource_enabled('network'):          tests.append(NetworkedTests) +        tests.append(NetworkedBIOTests)      if _have_threads:          thread_info = support.threading_setup()  | 
