diff options
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r-- | Lib/test/test_ssl.py | 371 |
1 files changed, 338 insertions, 33 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 2dea0c5..67eea98 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -86,6 +86,12 @@ def have_verify_flags(): # 0.9.8 or higher return ssl.OPENSSL_VERSION_INFO >= (0, 9, 8, 0, 15) +def utc_offset(): #NOTE: ignore issues like #1647654 + # local time = utc time + utc offset + if time.daylight and time.localtime().tm_isdst > 0: + return -time.altzone # seconds + return -time.timezone + def asn1time(cert_time): # Some versions of OpenSSL ignore seconds, see #18207 # 0.9.8.i @@ -134,6 +140,14 @@ class BasicSocketTests(unittest.TestCase): self.assertIn(ssl.HAS_SNI, {True, False}) self.assertIn(ssl.HAS_ECDH, {True, False}) + def test_str_for_enums(self): + # Make sure that the PROTOCOL_* constants have enum-like string + # reprs. + proto = ssl.PROTOCOL_SSLv3 + self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_SSLv3') + ctx = ssl.SSLContext(proto) + self.assertIs(ctx.protocol, proto) + def test_random(self): v = ssl.RAND_status() if support.verbose: @@ -154,8 +168,9 @@ class BasicSocketTests(unittest.TestCase): self.assertRaises(ValueError, ssl.RAND_bytes, -5) self.assertRaises(ValueError, ssl.RAND_pseudo_bytes, -5) - self.assertRaises(TypeError, ssl.RAND_egd, 1) - self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1) + if hasattr(ssl, 'RAND_egd'): + self.assertRaises(TypeError, ssl.RAND_egd, 1) + self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1) ssl.RAND_add("this is a random string", 75.0) @unittest.skipUnless(os.name == 'posix', 'requires posix') @@ -504,9 +519,14 @@ class BasicSocketTests(unittest.TestCase): def test_unknown_channel_binding(self): # should raise ValueError for unknown type s = socket.socket(socket.AF_INET) - with ssl.wrap_socket(s) as ss: + s.bind(('127.0.0.1', 0)) + s.listen() + c = socket.socket(socket.AF_INET) + c.connect(s.getsockname()) + with ssl.wrap_socket(c, do_handshake_on_connect=False) as ss: with self.assertRaises(ValueError): ss.get_channel_binding("unknown-type") + s.close() @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, "'tls-unique' channel binding not available") @@ -647,6 +667,71 @@ class BasicSocketTests(unittest.TestCase): ctx.wrap_socket(s) self.assertEqual(str(cx.exception), "only stream sockets are supported") + def cert_time_ok(self, timestring, timestamp): + self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp) + + def cert_time_fail(self, timestring): + with self.assertRaises(ValueError): + ssl.cert_time_to_seconds(timestring) + + @unittest.skipUnless(utc_offset(), + 'local time needs to be different from UTC') + def test_cert_time_to_seconds_timezone(self): + # Issue #19940: ssl.cert_time_to_seconds() returns wrong + # results if local timezone is not UTC + self.cert_time_ok("May 9 00:00:00 2007 GMT", 1178668800.0) + self.cert_time_ok("Jan 5 09:34:43 2018 GMT", 1515144883.0) + + def test_cert_time_to_seconds(self): + timestring = "Jan 5 09:34:43 2018 GMT" + ts = 1515144883.0 + self.cert_time_ok(timestring, ts) + # accept keyword parameter, assert its name + self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts) + # accept both %e and %d (space or zero generated by strftime) + self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts) + # case-insensitive + self.cert_time_ok("JaN 5 09:34:43 2018 GmT", ts) + self.cert_time_fail("Jan 5 09:34 2018 GMT") # no seconds + self.cert_time_fail("Jan 5 09:34:43 2018") # no GMT + self.cert_time_fail("Jan 5 09:34:43 2018 UTC") # not GMT timezone + self.cert_time_fail("Jan 35 09:34:43 2018 GMT") # invalid day + self.cert_time_fail("Jon 5 09:34:43 2018 GMT") # invalid month + self.cert_time_fail("Jan 5 24:00:00 2018 GMT") # invalid hour + self.cert_time_fail("Jan 5 09:60:43 2018 GMT") # invalid minute + + newyear_ts = 1230768000.0 + # leap seconds + self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts) + # same timestamp + self.cert_time_ok("Jan 1 00:00:00 2009 GMT", newyear_ts) + + self.cert_time_ok("Jan 5 09:34:59 2018 GMT", 1515144899) + # allow 60th second (even if it is not a leap second) + self.cert_time_ok("Jan 5 09:34:60 2018 GMT", 1515144900) + # allow 2nd leap second for compatibility with time.strptime() + self.cert_time_ok("Jan 5 09:34:61 2018 GMT", 1515144901) + self.cert_time_fail("Jan 5 09:34:62 2018 GMT") # invalid seconds + + # no special treatement for the special value: + # 99991231235959Z (rfc 5280) + self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0) + + @support.run_with_locale('LC_ALL', '') + def test_cert_time_to_seconds_locale(self): + # `cert_time_to_seconds()` should be locale independent + + def local_february_name(): + return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0)) + + if local_february_name().lower() == 'feb': + self.skipTest("locale-specific month name needs to be " + "different from C locale") + + # locale-independent + self.cert_time_ok("Feb 9 00:00:00 2007 GMT", 1170979200.0) + self.cert_time_fail(local_february_name() + " 9 00:00:00 2007 GMT") + class ContextTests(unittest.TestCase): @@ -1155,7 +1240,7 @@ class SSLErrorTests(unittest.TestCase): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) with socket.socket() as s: s.bind(("127.0.0.1", 0)) - s.listen(5) + s.listen() c = socket.socket() c.connect(s.getsockname()) c.setblocking(False) @@ -1168,6 +1253,69 @@ class SSLErrorTests(unittest.TestCase): self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ) +class MemoryBIOTests(unittest.TestCase): + + def test_read_write(self): + bio = ssl.MemoryBIO() + bio.write(b'foo') + self.assertEqual(bio.read(), b'foo') + self.assertEqual(bio.read(), b'') + bio.write(b'foo') + bio.write(b'bar') + self.assertEqual(bio.read(), b'foobar') + self.assertEqual(bio.read(), b'') + bio.write(b'baz') + self.assertEqual(bio.read(2), b'ba') + self.assertEqual(bio.read(1), b'z') + self.assertEqual(bio.read(1), b'') + + def test_eof(self): + bio = ssl.MemoryBIO() + self.assertFalse(bio.eof) + self.assertEqual(bio.read(), b'') + self.assertFalse(bio.eof) + bio.write(b'foo') + self.assertFalse(bio.eof) + bio.write_eof() + self.assertFalse(bio.eof) + self.assertEqual(bio.read(2), b'fo') + self.assertFalse(bio.eof) + self.assertEqual(bio.read(1), b'o') + self.assertTrue(bio.eof) + self.assertEqual(bio.read(), b'') + self.assertTrue(bio.eof) + + def test_pending(self): + bio = ssl.MemoryBIO() + self.assertEqual(bio.pending, 0) + bio.write(b'foo') + self.assertEqual(bio.pending, 3) + for i in range(3): + bio.read(1) + self.assertEqual(bio.pending, 3-i-1) + for i in range(3): + bio.write(b'x') + self.assertEqual(bio.pending, i+1) + bio.read() + self.assertEqual(bio.pending, 0) + + def test_buffer_types(self): + bio = ssl.MemoryBIO() + bio.write(b'foo') + self.assertEqual(bio.read(), b'foo') + bio.write(bytearray(b'bar')) + self.assertEqual(bio.read(), b'bar') + bio.write(memoryview(b'baz')) + self.assertEqual(bio.read(), b'baz') + + def test_error_types(self): + bio = ssl.MemoryBIO() + self.assertRaises(TypeError, bio.write, 'foo') + self.assertRaises(TypeError, bio.write, None) + self.assertRaises(TypeError, bio.write, True) + self.assertRaises(TypeError, bio.write, 1) + + class NetworkedTests(unittest.TestCase): def test_connect(self): @@ -1395,14 +1543,12 @@ class NetworkedTests(unittest.TestCase): def test_get_server_certificate(self): def _test_get_server_certificate(host, port, cert=None): with support.transient_internet(host): - pem = ssl.get_server_certificate((host, port), - ssl.PROTOCOL_SSLv23) + pem = ssl.get_server_certificate((host, port)) if not pem: self.fail("No server certificate on %s:%s!" % (host, port)) try: pem = ssl.get_server_certificate((host, port), - ssl.PROTOCOL_SSLv23, ca_certs=CERTFILE) except ssl.SSLError as x: #should fail @@ -1412,7 +1558,6 @@ class NetworkedTests(unittest.TestCase): self.fail("Got server certificate %s for %s:%s!" % (pem, host, port)) pem = ssl.get_server_certificate((host, port), - ssl.PROTOCOL_SSLv23, ca_certs=cert) if not pem: self.fail("No server certificate on %s:%s!" % (host, port)) @@ -1498,6 +1643,91 @@ class NetworkedTests(unittest.TestCase): self.assertIs(ss.context, ctx2) self.assertIs(ss._sslobj.context, ctx2) + +class NetworkedBIOTests(unittest.TestCase): + + def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs): + # A simple IO loop. Call func(*args) depending on the error we get + # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs. + timeout = kwargs.get('timeout', 10) + count = 0 + while True: + errno = None + count += 1 + try: + ret = func(*args) + except ssl.SSLError as e: + # Note that we get a spurious -1/SSL_ERROR_SYSCALL for + # non-blocking IO. The SSL_shutdown manpage hints at this. + # It *should* be safe to just ignore SYS_ERROR_SYSCALL because + # with a Memory BIO there's no syscalls (for IO at least). + if e.errno not in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + raise + errno = e.errno + # Get any data from the outgoing BIO irrespective of any error, and + # send it to the socket. + buf = outgoing.read() + sock.sendall(buf) + # If there's no error, we're done. For WANT_READ, we need to get + # data from the socket and put it in the incoming BIO. + if errno is None: + break + elif errno == ssl.SSL_ERROR_WANT_READ: + buf = sock.recv(32768) + if buf: + incoming.write(buf) + else: + incoming.write_eof() + if support.verbose: + sys.stdout.write("Needed %d calls to complete %s().\n" + % (count, func.__name__)) + return ret + + def test_handshake(self): + with support.transient_internet("svn.python.org"): + sock = socket.socket(socket.AF_INET) + sock.connect(("svn.python.org", 443)) + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT) + ctx.check_hostname = True + sslobj = ctx.wrap_bio(incoming, outgoing, False, 'svn.python.org') + self.assertIs(sslobj._sslobj.owner, sslobj) + self.assertIsNone(sslobj.cipher()) + self.assertRaises(ValueError, sslobj.getpeercert) + if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: + self.assertIsNone(sslobj.get_channel_binding('tls-unique')) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) + self.assertTrue(sslobj.cipher()) + self.assertTrue(sslobj.getpeercert()) + if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: + self.assertTrue(sslobj.get_channel_binding('tls-unique')) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) + self.assertRaises(ssl.SSLError, sslobj.write, b'foo') + sock.close() + + def test_read_write_data(self): + with support.transient_internet("svn.python.org"): + sock = socket.socket(socket.AF_INET) + sock.connect(("svn.python.org", 443)) + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_NONE + sslobj = ctx.wrap_bio(incoming, outgoing, False) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) + req = b'GET / HTTP/1.0\r\n\r\n' + self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req) + buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024) + self.assertEqual(buf[:5], b'HTTP/') + self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) + sock.close() + + try: import threading except ImportError: @@ -1680,7 +1910,7 @@ else: def run(self): self.sock.settimeout(0.05) - self.sock.listen(5) + self.sock.listen() self.active = True if self.flag: # signal an event @@ -1886,7 +2116,8 @@ else: 'compression': s.compression(), 'cipher': s.cipher(), 'peercert': s.getpeercert(), - 'client_npn_protocol': s.selected_npn_protocol() + 'client_npn_protocol': s.selected_npn_protocol(), + 'version': s.version(), }) s.close() stats['server_npn_protocols'] = server.selected_protocols @@ -1894,6 +2125,13 @@ else: def try_protocol_combo(server_protocol, client_protocol, expect_success, certsreqs=None, server_options=0, client_options=0): + """ + Try to SSL-connect using *client_protocol* to *server_protocol*. + If *expect_success* is true, assert that the connection succeeds, + if it's false, assert that the connection fails. + Also, if *expect_success* is a string, assert that it is the protocol + version actually used by the connection. + """ if certsreqs is None: certsreqs = ssl.CERT_NONE certtype = { @@ -1923,8 +2161,8 @@ else: ctx.load_cert_chain(CERTFILE) ctx.load_verify_locations(CERTFILE) try: - server_params_test(client_context, server_context, - chatty=False, connectionchatty=False) + stats = server_params_test(client_context, server_context, + chatty=False, connectionchatty=False) # Protocol mismatch can result in either an SSLError, or a # "Connection reset by peer" error. except ssl.SSLError: @@ -1939,6 +2177,10 @@ else: "Client protocol %s succeeded with server protocol %s!" % (ssl.get_protocol_name(client_protocol), ssl.get_protocol_name(server_protocol))) + elif (expect_success is not True + and expect_success != stats['version']): + raise AssertionError("version mismatch: expected %r, got %r" + % (expect_success, stats['version'])) class ThreadedTests(unittest.TestCase): @@ -2105,7 +2347,7 @@ else: # and sets Event `listener_gone` to let the main thread know # the socket is gone. def listener(): - s.listen(5) + s.listen() listener_ready.set() newsock, addr = s.accept() newsock.close() @@ -2169,19 +2411,19 @@ else: " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" % str(x)) if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3') try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1') if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) # Server with specific SSL options if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2201,9 +2443,9 @@ else: """Connecting to an SSLv3 server with various client options""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3') + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, @@ -2211,7 +2453,7 @@ else: try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) if no_sslv2_implies_sslv3_hello(): # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, True, + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, 'SSLv3', client_options=ssl.OP_NO_SSLv2) @skip_if_broken_ubuntu_ssl @@ -2219,9 +2461,9 @@ else: """Connecting to a TLSv1 server with various client options""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1') + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2237,7 +2479,7 @@ else: Testing against older TLS versions.""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2245,7 +2487,7 @@ else: try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_TLSv1_1) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) @@ -2258,7 +2500,7 @@ else: Testing against older TLS versions.""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, True, + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2', server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2, client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,) if hasattr(ssl, 'PROTOCOL_SSLv2'): @@ -2268,7 +2510,7 @@ else: try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_TLSv1_2) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) @@ -2504,6 +2746,36 @@ else: s.write(b"over\n") s.close() + def test_nonblocking_send(self): + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + with server: + s = ssl.wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + s.setblocking(False) + + # If we keep sending data, at some point the buffers + # will be full and the call will block + buf = bytearray(8192) + def fill_buffer(): + while True: + s.send(buf) + self.assertRaises((ssl.SSLWantWriteError, + ssl.SSLWantReadError), fill_buffer) + + # Now read all the output and discard it + s.setblocking(True) + s.close() + def test_handshake_timeout(self): # Issue #5103: SSL handshake must respect the socket timeout server = socket.socket(socket.AF_INET) @@ -2513,7 +2785,7 @@ else: finish = False def serve(): - server.listen(5) + server.listen() started.set() conns = [] while not finish: @@ -2570,7 +2842,7 @@ else: peer = None def serve(): nonlocal remote, peer - server.listen(5) + server.listen() # Block on the accept and wait on the connection to close. evt.set() remote, peer = server.accept() @@ -2620,6 +2892,21 @@ else: s.connect((HOST, server.port)) self.assertIn("no shared cipher", str(server.conn_errors[0])) + def test_version_basic(self): + """ + Basic tests for SSLSocket.version(). + More tests are done in the test_protocol_*() methods. + """ + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + with ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_TLSv1, + chatty=False) as server: + with context.wrap_socket(socket.socket()) as s: + self.assertIs(s.version(), None) + s.connect((HOST, server.port)) + self.assertEqual(s.version(), "TLSv1") + self.assertIs(s.version(), None) + @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL") def test_default_ecdh_curve(self): # Issue #21015: elliptic curve-based Diffie Hellman key exchange @@ -2884,6 +3171,23 @@ else: self.assertRaises(ValueError, s.read, 1024) self.assertRaises(ValueError, s.write, b'hello') + def test_sendfile(self): + TEST_DATA = b"x" * 512 + with open(support.TESTFN, 'wb') as f: + f.write(TEST_DATA) + self.addCleanup(support.unlink, support.TESTFN) + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + with server: + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + with open(support.TESTFN, 'rb') as file: + s.sendfile(file) + self.assertEqual(s.recv(1024), TEST_DATA) + def test_main(verbose=False): if support.verbose: @@ -2917,10 +3221,11 @@ def test_main(verbose=False): if not os.path.exists(filename): raise support.TestFailed("Can't read certificate file %r" % filename) - tests = [ContextTests, BasicSocketTests, SSLErrorTests] + tests = [ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests] if support.is_resource_enabled('network'): tests.append(NetworkedTests) + tests.append(NetworkedBIOTests) if _have_threads: thread_info = support.threading_setup() |