diff options
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r-- | Lib/test/test_ssl.py | 479 |
1 files changed, 442 insertions, 37 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index cdf3aed..9140fc1 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -85,6 +85,12 @@ def have_verify_flags(): # 0.9.8 or higher return ssl.OPENSSL_VERSION_INFO >= (0, 9, 8, 0, 15) +def utc_offset(): #NOTE: ignore issues like #1647654 + # local time = utc time + utc offset + if time.daylight and time.localtime().tm_isdst > 0: + return -time.altzone # seconds + return -time.timezone + def asn1time(cert_time): # Some versions of OpenSSL ignore seconds, see #18207 # 0.9.8.i @@ -133,6 +139,14 @@ class BasicSocketTests(unittest.TestCase): self.assertIn(ssl.HAS_SNI, {True, False}) self.assertIn(ssl.HAS_ECDH, {True, False}) + def test_str_for_enums(self): + # Make sure that the PROTOCOL_* constants have enum-like string + # reprs. + proto = ssl.PROTOCOL_SSLv23 + self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_SSLv23') + ctx = ssl.SSLContext(proto) + self.assertIs(ctx.protocol, proto) + def test_random(self): v = ssl.RAND_status() if support.verbose: @@ -157,6 +171,8 @@ class BasicSocketTests(unittest.TestCase): self.assertRaises(TypeError, ssl.RAND_egd, 1) self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1) ssl.RAND_add("this is a random string", 75.0) + ssl.RAND_add(b"this is a random bytes object", 75.0) + ssl.RAND_add(bytearray(b"this is a random bytearray object"), 75.0) @unittest.skipUnless(os.name == 'posix', 'requires posix') def test_random_fork(self): @@ -297,10 +313,10 @@ class BasicSocketTests(unittest.TestCase): # Version string as returned by {Open,Libre}SSL, the format might change if "LibreSSL" in s: self.assertTrue(s.startswith("LibreSSL {:d}.{:d}".format(major, minor)), - (s, t)) + (s, t, hex(n))) else: self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)), - (s, t)) + (s, t, hex(n))) @support.cpython_only def test_refcycle(self): @@ -368,6 +384,8 @@ class BasicSocketTests(unittest.TestCase): self.assertRaises(ssl.CertificateError, ssl.match_hostname, cert, hostname) + # -- Hostname matching -- + cert = {'subject': ((('commonName', 'example.com'),),)} ok(cert, 'example.com') ok(cert, 'ExAmple.cOm') @@ -453,6 +471,28 @@ class BasicSocketTests(unittest.TestCase): # Only commonName is considered fail(cert, 'California') + # -- IPv4 matching -- + cert = {'subject': ((('commonName', 'example.com'),),), + 'subjectAltName': (('DNS', 'example.com'), + ('IP Address', '10.11.12.13'), + ('IP Address', '14.15.16.17'))} + ok(cert, '10.11.12.13') + ok(cert, '14.15.16.17') + fail(cert, '14.15.16.18') + fail(cert, 'example.net') + + # -- IPv6 matching -- + cert = {'subject': ((('commonName', 'example.com'),),), + 'subjectAltName': (('DNS', 'example.com'), + ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'), + ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))} + ok(cert, '2001::cafe') + ok(cert, '2003::baba') + fail(cert, '2003::bebe') + fail(cert, 'example.net') + + # -- Miscellaneous -- + # Neither commonName nor subjectAltName cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT', 'subject': ((('countryName', 'US'),), @@ -504,9 +544,14 @@ class BasicSocketTests(unittest.TestCase): def test_unknown_channel_binding(self): # should raise ValueError for unknown type s = socket.socket(socket.AF_INET) - with ssl.wrap_socket(s) as ss: + s.bind(('127.0.0.1', 0)) + s.listen() + c = socket.socket(socket.AF_INET) + c.connect(s.getsockname()) + with ssl.wrap_socket(c, do_handshake_on_connect=False) as ss: with self.assertRaises(ValueError): ss.get_channel_binding("unknown-type") + s.close() @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, "'tls-unique' channel binding not available") @@ -647,6 +692,71 @@ class BasicSocketTests(unittest.TestCase): ctx.wrap_socket(s) self.assertEqual(str(cx.exception), "only stream sockets are supported") + def cert_time_ok(self, timestring, timestamp): + self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp) + + def cert_time_fail(self, timestring): + with self.assertRaises(ValueError): + ssl.cert_time_to_seconds(timestring) + + @unittest.skipUnless(utc_offset(), + 'local time needs to be different from UTC') + def test_cert_time_to_seconds_timezone(self): + # Issue #19940: ssl.cert_time_to_seconds() returns wrong + # results if local timezone is not UTC + self.cert_time_ok("May 9 00:00:00 2007 GMT", 1178668800.0) + self.cert_time_ok("Jan 5 09:34:43 2018 GMT", 1515144883.0) + + def test_cert_time_to_seconds(self): + timestring = "Jan 5 09:34:43 2018 GMT" + ts = 1515144883.0 + self.cert_time_ok(timestring, ts) + # accept keyword parameter, assert its name + self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts) + # accept both %e and %d (space or zero generated by strftime) + self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts) + # case-insensitive + self.cert_time_ok("JaN 5 09:34:43 2018 GmT", ts) + self.cert_time_fail("Jan 5 09:34 2018 GMT") # no seconds + self.cert_time_fail("Jan 5 09:34:43 2018") # no GMT + self.cert_time_fail("Jan 5 09:34:43 2018 UTC") # not GMT timezone + self.cert_time_fail("Jan 35 09:34:43 2018 GMT") # invalid day + self.cert_time_fail("Jon 5 09:34:43 2018 GMT") # invalid month + self.cert_time_fail("Jan 5 24:00:00 2018 GMT") # invalid hour + self.cert_time_fail("Jan 5 09:60:43 2018 GMT") # invalid minute + + newyear_ts = 1230768000.0 + # leap seconds + self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts) + # same timestamp + self.cert_time_ok("Jan 1 00:00:00 2009 GMT", newyear_ts) + + self.cert_time_ok("Jan 5 09:34:59 2018 GMT", 1515144899) + # allow 60th second (even if it is not a leap second) + self.cert_time_ok("Jan 5 09:34:60 2018 GMT", 1515144900) + # allow 2nd leap second for compatibility with time.strptime() + self.cert_time_ok("Jan 5 09:34:61 2018 GMT", 1515144901) + self.cert_time_fail("Jan 5 09:34:62 2018 GMT") # invalid seconds + + # no special treatement for the special value: + # 99991231235959Z (rfc 5280) + self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0) + + @support.run_with_locale('LC_ALL', '') + def test_cert_time_to_seconds_locale(self): + # `cert_time_to_seconds()` should be locale independent + + def local_february_name(): + return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0)) + + if local_february_name().lower() == 'feb': + self.skipTest("locale-specific month name needs to be " + "different from C locale") + + # locale-independent + self.cert_time_ok("Feb 9 00:00:00 2007 GMT", 1170979200.0) + self.cert_time_fail(local_february_name() + " 9 00:00:00 2007 GMT") + class ContextTests(unittest.TestCase): @@ -1156,7 +1266,7 @@ class SSLErrorTests(unittest.TestCase): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) with socket.socket() as s: s.bind(("127.0.0.1", 0)) - s.listen(5) + s.listen() c = socket.socket() c.connect(s.getsockname()) c.setblocking(False) @@ -1169,6 +1279,69 @@ class SSLErrorTests(unittest.TestCase): self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ) +class MemoryBIOTests(unittest.TestCase): + + def test_read_write(self): + bio = ssl.MemoryBIO() + bio.write(b'foo') + self.assertEqual(bio.read(), b'foo') + self.assertEqual(bio.read(), b'') + bio.write(b'foo') + bio.write(b'bar') + self.assertEqual(bio.read(), b'foobar') + self.assertEqual(bio.read(), b'') + bio.write(b'baz') + self.assertEqual(bio.read(2), b'ba') + self.assertEqual(bio.read(1), b'z') + self.assertEqual(bio.read(1), b'') + + def test_eof(self): + bio = ssl.MemoryBIO() + self.assertFalse(bio.eof) + self.assertEqual(bio.read(), b'') + self.assertFalse(bio.eof) + bio.write(b'foo') + self.assertFalse(bio.eof) + bio.write_eof() + self.assertFalse(bio.eof) + self.assertEqual(bio.read(2), b'fo') + self.assertFalse(bio.eof) + self.assertEqual(bio.read(1), b'o') + self.assertTrue(bio.eof) + self.assertEqual(bio.read(), b'') + self.assertTrue(bio.eof) + + def test_pending(self): + bio = ssl.MemoryBIO() + self.assertEqual(bio.pending, 0) + bio.write(b'foo') + self.assertEqual(bio.pending, 3) + for i in range(3): + bio.read(1) + self.assertEqual(bio.pending, 3-i-1) + for i in range(3): + bio.write(b'x') + self.assertEqual(bio.pending, i+1) + bio.read() + self.assertEqual(bio.pending, 0) + + def test_buffer_types(self): + bio = ssl.MemoryBIO() + bio.write(b'foo') + self.assertEqual(bio.read(), b'foo') + bio.write(bytearray(b'bar')) + self.assertEqual(bio.read(), b'bar') + bio.write(memoryview(b'baz')) + self.assertEqual(bio.read(), b'baz') + + def test_error_types(self): + bio = ssl.MemoryBIO() + self.assertRaises(TypeError, bio.write, 'foo') + self.assertRaises(TypeError, bio.write, None) + self.assertRaises(TypeError, bio.write, True) + self.assertRaises(TypeError, bio.write, 1) + + class NetworkedTests(unittest.TestCase): def test_connect(self): @@ -1396,14 +1569,12 @@ class NetworkedTests(unittest.TestCase): def test_get_server_certificate(self): def _test_get_server_certificate(host, port, cert=None): with support.transient_internet(host): - pem = ssl.get_server_certificate((host, port), - ssl.PROTOCOL_SSLv23) + pem = ssl.get_server_certificate((host, port)) if not pem: self.fail("No server certificate on %s:%s!" % (host, port)) try: pem = ssl.get_server_certificate((host, port), - ssl.PROTOCOL_SSLv23, ca_certs=CERTFILE) except ssl.SSLError as x: #should fail @@ -1413,7 +1584,6 @@ class NetworkedTests(unittest.TestCase): self.fail("Got server certificate %s for %s:%s!" % (pem, host, port)) pem = ssl.get_server_certificate((host, port), - ssl.PROTOCOL_SSLv23, ca_certs=cert) if not pem: self.fail("No server certificate on %s:%s!" % (host, port)) @@ -1499,6 +1669,93 @@ class NetworkedTests(unittest.TestCase): self.assertIs(ss.context, ctx2) self.assertIs(ss._sslobj.context, ctx2) + +class NetworkedBIOTests(unittest.TestCase): + + def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs): + # A simple IO loop. Call func(*args) depending on the error we get + # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs. + timeout = kwargs.get('timeout', 10) + count = 0 + while True: + errno = None + count += 1 + try: + ret = func(*args) + except ssl.SSLError as e: + # Note that we get a spurious -1/SSL_ERROR_SYSCALL for + # non-blocking IO. The SSL_shutdown manpage hints at this. + # It *should* be safe to just ignore SYS_ERROR_SYSCALL because + # with a Memory BIO there's no syscalls (for IO at least). + if e.errno not in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + raise + errno = e.errno + # Get any data from the outgoing BIO irrespective of any error, and + # send it to the socket. + buf = outgoing.read() + sock.sendall(buf) + # If there's no error, we're done. For WANT_READ, we need to get + # data from the socket and put it in the incoming BIO. + if errno is None: + break + elif errno == ssl.SSL_ERROR_WANT_READ: + buf = sock.recv(32768) + if buf: + incoming.write(buf) + else: + incoming.write_eof() + if support.verbose: + sys.stdout.write("Needed %d calls to complete %s().\n" + % (count, func.__name__)) + return ret + + def test_handshake(self): + with support.transient_internet("svn.python.org"): + sock = socket.socket(socket.AF_INET) + sock.connect(("svn.python.org", 443)) + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT) + ctx.check_hostname = True + sslobj = ctx.wrap_bio(incoming, outgoing, False, 'svn.python.org') + self.assertIs(sslobj._sslobj.owner, sslobj) + self.assertIsNone(sslobj.cipher()) + self.assertIsNone(sslobj.shared_ciphers()) + self.assertRaises(ValueError, sslobj.getpeercert) + if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: + self.assertIsNone(sslobj.get_channel_binding('tls-unique')) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) + self.assertTrue(sslobj.cipher()) + self.assertIsNone(sslobj.shared_ciphers()) + self.assertTrue(sslobj.getpeercert()) + if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: + self.assertTrue(sslobj.get_channel_binding('tls-unique')) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) + self.assertRaises(ssl.SSLError, sslobj.write, b'foo') + sock.close() + + def test_read_write_data(self): + with support.transient_internet("svn.python.org"): + sock = socket.socket(socket.AF_INET) + sock.connect(("svn.python.org", 443)) + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_NONE + sslobj = ctx.wrap_bio(incoming, outgoing, False) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) + req = b'GET / HTTP/1.0\r\n\r\n' + self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req) + buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024) + self.assertEqual(buf[:5], b'HTTP/') + self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) + sock.close() + + try: import threading except ImportError: @@ -1530,7 +1787,8 @@ else: try: self.sslconn = self.server.context.wrap_socket( self.sock, server_side=True) - self.server.selected_protocols.append(self.sslconn.selected_npn_protocol()) + self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol()) + self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol()) except (ssl.SSLError, ConnectionResetError) as e: # We treat ConnectionResetError as though it were an # SSLError - OpenSSL on Ubuntu abruptly closes the @@ -1547,6 +1805,7 @@ else: self.close() return False else: + self.server.shared_ciphers.append(self.sslconn.shared_ciphers()) if self.server.context.verify_mode == ssl.CERT_REQUIRED: cert = self.sslconn.getpeercert() if support.verbose and self.server.chatty: @@ -1637,7 +1896,8 @@ else: def __init__(self, certificate=None, ssl_version=None, certreqs=None, cacerts=None, chatty=True, connectionchatty=False, starttls_server=False, - npn_protocols=None, ciphers=None, context=None): + npn_protocols=None, alpn_protocols=None, + ciphers=None, context=None): if context: self.context = context else: @@ -1652,6 +1912,8 @@ else: self.context.load_cert_chain(certificate) if npn_protocols: self.context.set_npn_protocols(npn_protocols) + if alpn_protocols: + self.context.set_alpn_protocols(alpn_protocols) if ciphers: self.context.set_ciphers(ciphers) self.chatty = chatty @@ -1661,7 +1923,9 @@ else: self.port = support.bind_port(self.sock) self.flag = None self.active = False - self.selected_protocols = [] + self.selected_npn_protocols = [] + self.selected_alpn_protocols = [] + self.shared_ciphers = [] self.conn_errors = [] threading.Thread.__init__(self) self.daemon = True @@ -1681,7 +1945,7 @@ else: def run(self): self.sock.settimeout(0.05) - self.sock.listen(5) + self.sock.listen() self.active = True if self.flag: # signal an event @@ -1887,14 +2151,25 @@ else: 'compression': s.compression(), 'cipher': s.cipher(), 'peercert': s.getpeercert(), - 'client_npn_protocol': s.selected_npn_protocol() + 'client_alpn_protocol': s.selected_alpn_protocol(), + 'client_npn_protocol': s.selected_npn_protocol(), + 'version': s.version(), }) s.close() - stats['server_npn_protocols'] = server.selected_protocols + stats['server_alpn_protocols'] = server.selected_alpn_protocols + stats['server_npn_protocols'] = server.selected_npn_protocols + stats['server_shared_ciphers'] = server.shared_ciphers return stats def try_protocol_combo(server_protocol, client_protocol, expect_success, certsreqs=None, server_options=0, client_options=0): + """ + Try to SSL-connect using *client_protocol* to *server_protocol*. + If *expect_success* is true, assert that the connection succeeds, + if it's false, assert that the connection fails. + Also, if *expect_success* is a string, assert that it is the protocol + version actually used by the connection. + """ if certsreqs is None: certsreqs = ssl.CERT_NONE certtype = { @@ -1924,8 +2199,8 @@ else: ctx.load_cert_chain(CERTFILE) ctx.load_verify_locations(CERTFILE) try: - server_params_test(client_context, server_context, - chatty=False, connectionchatty=False) + stats = server_params_test(client_context, server_context, + chatty=False, connectionchatty=False) # Protocol mismatch can result in either an SSLError, or a # "Connection reset by peer" error. except ssl.SSLError: @@ -1940,6 +2215,10 @@ else: "Client protocol %s succeeded with server protocol %s!" % (ssl.get_protocol_name(client_protocol), ssl.get_protocol_name(server_protocol))) + elif (expect_success is not True + and expect_success != stats['version']): + raise AssertionError("version mismatch: expected %r, got %r" + % (expect_success, stats['version'])) class ThreadedTests(unittest.TestCase): @@ -2107,7 +2386,7 @@ else: # and sets Event `listener_gone` to let the main thread know # the socket is gone. def listener(): - s.listen(5) + s.listen() listener_ready.set() newsock, addr = s.accept() newsock.close() @@ -2172,19 +2451,19 @@ else: " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" % str(x)) if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3') try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1') if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) # Server with specific SSL options if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2204,9 +2483,9 @@ else: """Connecting to an SSLv3 server with various client options""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3') + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, @@ -2214,7 +2493,7 @@ else: try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) if no_sslv2_implies_sslv3_hello(): # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, True, + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, 'SSLv3', client_options=ssl.OP_NO_SSLv2) @skip_if_broken_ubuntu_ssl @@ -2222,9 +2501,9 @@ else: """Connecting to a TLSv1 server with various client options""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1') + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2240,7 +2519,7 @@ else: Testing against older TLS versions.""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2248,7 +2527,7 @@ else: try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_TLSv1_1) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) @@ -2261,7 +2540,7 @@ else: Testing against older TLS versions.""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, True, + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2', server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2, client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,) if hasattr(ssl, 'PROTOCOL_SSLv2'): @@ -2271,7 +2550,7 @@ else: try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_TLSv1_2) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) @@ -2507,6 +2786,36 @@ else: s.write(b"over\n") s.close() + def test_nonblocking_send(self): + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + with server: + s = ssl.wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + s.setblocking(False) + + # If we keep sending data, at some point the buffers + # will be full and the call will block + buf = bytearray(8192) + def fill_buffer(): + while True: + s.send(buf) + self.assertRaises((ssl.SSLWantWriteError, + ssl.SSLWantReadError), fill_buffer) + + # Now read all the output and discard it + s.setblocking(True) + s.close() + def test_handshake_timeout(self): # Issue #5103: SSL handshake must respect the socket timeout server = socket.socket(socket.AF_INET) @@ -2516,7 +2825,7 @@ else: finish = False def serve(): - server.listen(5) + server.listen() started.set() conns = [] while not finish: @@ -2573,7 +2882,7 @@ else: peer = None def serve(): nonlocal remote, peer - server.listen(5) + server.listen() # Block on the accept and wait on the connection to close. evt.set() remote, peer = server.accept() @@ -2623,6 +2932,21 @@ else: s.connect((HOST, server.port)) self.assertIn("no shared cipher", str(server.conn_errors[0])) + def test_version_basic(self): + """ + Basic tests for SSLSocket.version(). + More tests are done in the test_protocol_*() methods. + """ + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + with ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_TLSv1, + chatty=False) as server: + with context.wrap_socket(socket.socket()) as s: + self.assertIs(s.version(), None) + s.connect((HOST, server.port)) + self.assertEqual(s.version(), "TLSv1") + self.assertIs(s.version(), None) + @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL") def test_default_ecdh_curve(self): # Issue #21015: elliptic curve-based Diffie Hellman key exchange @@ -2732,6 +3056,55 @@ else: if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: self.fail("Non-DH cipher: " + cipher[0]) + def test_selected_alpn_protocol(self): + # selected_alpn_protocol() is None unless ALPN is used. + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + self.assertIs(stats['client_alpn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required") + def test_selected_alpn_protocol_if_server_uses_alpn(self): + # selected_alpn_protocol() is None unless ALPN is used by the client. + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.load_verify_locations(CERTFILE) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(CERTFILE) + server_context.set_alpn_protocols(['foo', 'bar']) + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True) + self.assertIs(stats['client_alpn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test") + def test_alpn_protocols(self): + server_protocols = ['foo', 'bar', 'milkshake'] + protocol_tests = [ + (['foo', 'bar'], 'foo'), + (['bar', 'foo'], 'foo'), + (['milkshake'], 'milkshake'), + (['http/3.0', 'http/4.0'], None) + ] + for client_protocols, expected in protocol_tests: + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(CERTFILE) + server_context.set_alpn_protocols(server_protocols) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.load_cert_chain(CERTFILE) + client_context.set_alpn_protocols(client_protocols) + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True) + + msg = "failed trying %s (s) and %s (c).\n" \ + "was expecting %s, but got %%s from the %%s" \ + % (str(server_protocols), str(client_protocols), + str(expected)) + client_result = stats['client_alpn_protocol'] + self.assertEqual(client_result, expected, msg % (client_result, "client")) + server_result = stats['server_alpn_protocols'][-1] \ + if len(stats['server_alpn_protocols']) else 'nothing' + self.assertEqual(server_result, expected, msg % (server_result, "server")) + def test_selected_npn_protocol(self): # selected_npn_protocol() is None unless NPN is used context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) @@ -2872,6 +3245,20 @@ else: self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR') self.assertIn("TypeError", stderr.getvalue()) + def test_shared_ciphers(self): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_verify_locations(SIGNING_CA) + client_context.set_ciphers("RC4") + server_context.set_ciphers("AES:RC4") + stats = server_params_test(client_context, server_context) + ciphers = stats['server_shared_ciphers'][0] + self.assertGreater(len(ciphers), 0) + for name, tls_version, bits in ciphers: + self.assertIn("RC4", name.split("-")) + def test_read_write_after_close_raises_valuerror(self): context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) context.verify_mode = ssl.CERT_REQUIRED @@ -2887,6 +3274,23 @@ else: self.assertRaises(ValueError, s.read, 1024) self.assertRaises(ValueError, s.write, b'hello') + def test_sendfile(self): + TEST_DATA = b"x" * 512 + with open(support.TESTFN, 'wb') as f: + f.write(TEST_DATA) + self.addCleanup(support.unlink, support.TESTFN) + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + with server: + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + with open(support.TESTFN, 'rb') as file: + s.sendfile(file) + self.assertEqual(s.recv(1024), TEST_DATA) + def test_main(verbose=False): if support.verbose: @@ -2920,10 +3324,11 @@ def test_main(verbose=False): if not os.path.exists(filename): raise support.TestFailed("Can't read certificate file %r" % filename) - tests = [ContextTests, BasicSocketTests, SSLErrorTests] + tests = [ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests] if support.is_resource_enabled('network'): tests.append(NetworkedTests) + tests.append(NetworkedBIOTests) if _have_threads: thread_info = support.threading_setup() |