diff options
author | Antoine Pitrou <pitrou@free.fr> | 2017-09-07 16:56:24 (GMT) |
---|---|---|
committer | Victor Stinner <victor.stinner@gmail.com> | 2017-09-07 16:56:24 (GMT) |
commit | a6a4dc816d68df04a7d592e0b6af8c7ecc4d4344 (patch) | |
tree | 1c31738009bee903417cea928e705a112aea2392 /Lib/test/test_ssl.py | |
parent | 1f06a680de465be0c24a78ea3b610053955daa99 (diff) | |
download | cpython-a6a4dc816d68df04a7d592e0b6af8c7ecc4d4344.zip cpython-a6a4dc816d68df04a7d592e0b6af8c7ecc4d4344.tar.gz cpython-a6a4dc816d68df04a7d592e0b6af8c7ecc4d4344.tar.bz2 |
bpo-31370: Remove support for threads-less builds (#3385)
* Remove Setup.config
* Always define WITH_THREAD for compatibility.
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r-- | Lib/test/test_ssl.py | 3222 |
1 files changed, 1605 insertions, 1617 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 16cad9d..89b4609 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -12,6 +12,7 @@ import os import errno import pprint import urllib.request +import threading import traceback import asyncore import weakref @@ -20,12 +21,6 @@ import functools ssl = support.import_module("ssl") -try: - import threading -except ImportError: - _have_threads = False -else: - _have_threads = True PROTOCOLS = sorted(ssl._PROTOCOL_NAMES) HOST = support.HOST @@ -1468,7 +1463,6 @@ class MemoryBIOTests(unittest.TestCase): self.assertRaises(TypeError, bio.write, 1) -@unittest.skipUnless(_have_threads, "Needs threading module") class SimpleBackgroundTests(unittest.TestCase): """Tests that connect to a simple server running in the background""" @@ -1828,1744 +1822,1743 @@ def _test_get_server_certificate_fail(test, host, port): test.fail("Got server certificate %s for %s:%s!" % (pem, host, port)) -if _have_threads: - from test.ssl_servers import make_https_server +from test.ssl_servers import make_https_server - class ThreadedEchoServer(threading.Thread): +class ThreadedEchoServer(threading.Thread): - class ConnectionHandler(threading.Thread): + class ConnectionHandler(threading.Thread): - """A mildly complicated class, because we want it to work both - with and without the SSL wrapper around the socket connection, so - that we can test the STARTTLS functionality.""" + """A mildly complicated class, because we want it to work both + with and without the SSL wrapper around the socket connection, so + that we can test the STARTTLS functionality.""" - def __init__(self, server, connsock, addr): - self.server = server + def __init__(self, server, connsock, addr): + self.server = server + self.running = False + self.sock = connsock + self.addr = addr + self.sock.setblocking(1) + self.sslconn = None + threading.Thread.__init__(self) + self.daemon = True + + def wrap_conn(self): + try: + self.sslconn = self.server.context.wrap_socket( + self.sock, server_side=True) + self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol()) + self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol()) + except (ssl.SSLError, ConnectionResetError, OSError) as e: + # We treat ConnectionResetError as though it were an + # SSLError - OpenSSL on Ubuntu abruptly closes the + # connection when asked to use an unsupported protocol. + # + # OSError may occur with wrong protocols, e.g. both + # sides use PROTOCOL_TLS_SERVER. + # + # XXX Various errors can have happened here, for example + # a mismatching protocol version, an invalid certificate, + # or a low-level bug. This should be made more discriminating. + # + # bpo-31323: Store the exception as string to prevent + # a reference leak: server -> conn_errors -> exception + # -> traceback -> self (ConnectionHandler) -> server + self.server.conn_errors.append(str(e)) + if self.server.chatty: + handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n") self.running = False - self.sock = connsock - self.addr = addr - self.sock.setblocking(1) - self.sslconn = None - threading.Thread.__init__(self) - self.daemon = True - - def wrap_conn(self): - try: - self.sslconn = self.server.context.wrap_socket( - self.sock, server_side=True) - self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol()) - self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol()) - except (ssl.SSLError, ConnectionResetError, OSError) as e: - # We treat ConnectionResetError as though it were an - # SSLError - OpenSSL on Ubuntu abruptly closes the - # connection when asked to use an unsupported protocol. - # - # OSError may occur with wrong protocols, e.g. both - # sides use PROTOCOL_TLS_SERVER. - # - # XXX Various errors can have happened here, for example - # a mismatching protocol version, an invalid certificate, - # or a low-level bug. This should be made more discriminating. - # - # bpo-31323: Store the exception as string to prevent - # a reference leak: server -> conn_errors -> exception - # -> traceback -> self (ConnectionHandler) -> server - self.server.conn_errors.append(str(e)) - if self.server.chatty: - handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n") - self.running = False - self.server.stop() - self.close() - return False - else: - self.server.shared_ciphers.append(self.sslconn.shared_ciphers()) - if self.server.context.verify_mode == ssl.CERT_REQUIRED: - cert = self.sslconn.getpeercert() - if support.verbose and self.server.chatty: - sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n") - cert_binary = self.sslconn.getpeercert(True) - if support.verbose and self.server.chatty: - sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n") - cipher = self.sslconn.cipher() + self.server.stop() + self.close() + return False + else: + self.server.shared_ciphers.append(self.sslconn.shared_ciphers()) + if self.server.context.verify_mode == ssl.CERT_REQUIRED: + cert = self.sslconn.getpeercert() if support.verbose and self.server.chatty: - sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") - sys.stdout.write(" server: selected protocol is now " - + str(self.sslconn.selected_npn_protocol()) + "\n") - return True - - def read(self): - if self.sslconn: - return self.sslconn.read() - else: - return self.sock.recv(1024) + sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n") + cert_binary = self.sslconn.getpeercert(True) + if support.verbose and self.server.chatty: + sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n") + cipher = self.sslconn.cipher() + if support.verbose and self.server.chatty: + sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") + sys.stdout.write(" server: selected protocol is now " + + str(self.sslconn.selected_npn_protocol()) + "\n") + return True + + def read(self): + if self.sslconn: + return self.sslconn.read() + else: + return self.sock.recv(1024) - def write(self, bytes): - if self.sslconn: - return self.sslconn.write(bytes) - else: - return self.sock.send(bytes) + def write(self, bytes): + if self.sslconn: + return self.sslconn.write(bytes) + else: + return self.sock.send(bytes) - def close(self): - if self.sslconn: - self.sslconn.close() - else: - self.sock.close() + def close(self): + if self.sslconn: + self.sslconn.close() + else: + self.sock.close() - def run(self): - self.running = True - if not self.server.starttls_server: - if not self.wrap_conn(): - return - while self.running: - try: - msg = self.read() - stripped = msg.strip() - if not stripped: - # eof, so quit this handler - self.running = False - try: - self.sock = self.sslconn.unwrap() - except OSError: - # Many tests shut the TCP connection down - # without an SSL shutdown. This causes - # unwrap() to raise OSError with errno=0! - pass - else: - self.sslconn = None - self.close() - elif stripped == b'over': - if support.verbose and self.server.connectionchatty: - sys.stdout.write(" server: client closed connection\n") - self.close() - return - elif (self.server.starttls_server and - stripped == b'STARTTLS'): - if support.verbose and self.server.connectionchatty: - sys.stdout.write(" server: read STARTTLS from client, sending OK...\n") - self.write(b"OK\n") - if not self.wrap_conn(): - return - elif (self.server.starttls_server and self.sslconn - and stripped == b'ENDTLS'): - if support.verbose and self.server.connectionchatty: - sys.stdout.write(" server: read ENDTLS from client, sending OK...\n") - self.write(b"OK\n") + def run(self): + self.running = True + if not self.server.starttls_server: + if not self.wrap_conn(): + return + while self.running: + try: + msg = self.read() + stripped = msg.strip() + if not stripped: + # eof, so quit this handler + self.running = False + try: self.sock = self.sslconn.unwrap() - self.sslconn = None - if support.verbose and self.server.connectionchatty: - sys.stdout.write(" server: connection is now unencrypted...\n") - elif stripped == b'CB tls-unique': - if support.verbose and self.server.connectionchatty: - sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n") - data = self.sslconn.get_channel_binding("tls-unique") - self.write(repr(data).encode("us-ascii") + b"\n") + except OSError: + # Many tests shut the TCP connection down + # without an SSL shutdown. This causes + # unwrap() to raise OSError with errno=0! + pass else: - if (support.verbose and - self.server.connectionchatty): - ctype = (self.sslconn and "encrypted") or "unencrypted" - sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n" - % (msg, ctype, msg.lower(), ctype)) - self.write(msg.lower()) - except OSError: - if self.server.chatty: - handle_error("Test server failure:\n") + self.sslconn = None self.close() - self.running = False - # normally, we'd just stop here, but for the test - # harness, we want to stop the server - self.server.stop() - - def __init__(self, certificate=None, ssl_version=None, - certreqs=None, cacerts=None, - chatty=True, connectionchatty=False, starttls_server=False, - npn_protocols=None, alpn_protocols=None, - ciphers=None, context=None): - if context: - self.context = context - else: - self.context = ssl.SSLContext(ssl_version - if ssl_version is not None - else ssl.PROTOCOL_TLSv1) - self.context.verify_mode = (certreqs if certreqs is not None - else ssl.CERT_NONE) - if cacerts: - self.context.load_verify_locations(cacerts) - if certificate: - self.context.load_cert_chain(certificate) - if npn_protocols: - self.context.set_npn_protocols(npn_protocols) - if alpn_protocols: - self.context.set_alpn_protocols(alpn_protocols) - if ciphers: - self.context.set_ciphers(ciphers) - self.chatty = chatty - self.connectionchatty = connectionchatty - self.starttls_server = starttls_server - self.sock = socket.socket() - self.port = support.bind_port(self.sock) - self.flag = None - self.active = False - self.selected_npn_protocols = [] - self.selected_alpn_protocols = [] - self.shared_ciphers = [] - self.conn_errors = [] - threading.Thread.__init__(self) - self.daemon = True - - def __enter__(self): - self.start(threading.Event()) - self.flag.wait() - return self - - def __exit__(self, *args): - self.stop() - self.join() - - def start(self, flag=None): - self.flag = flag - threading.Thread.start(self) + elif stripped == b'over': + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: client closed connection\n") + self.close() + return + elif (self.server.starttls_server and + stripped == b'STARTTLS'): + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: read STARTTLS from client, sending OK...\n") + self.write(b"OK\n") + if not self.wrap_conn(): + return + elif (self.server.starttls_server and self.sslconn + and stripped == b'ENDTLS'): + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: read ENDTLS from client, sending OK...\n") + self.write(b"OK\n") + self.sock = self.sslconn.unwrap() + self.sslconn = None + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: connection is now unencrypted...\n") + elif stripped == b'CB tls-unique': + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n") + data = self.sslconn.get_channel_binding("tls-unique") + self.write(repr(data).encode("us-ascii") + b"\n") + else: + if (support.verbose and + self.server.connectionchatty): + ctype = (self.sslconn and "encrypted") or "unencrypted" + sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n" + % (msg, ctype, msg.lower(), ctype)) + self.write(msg.lower()) + except OSError: + if self.server.chatty: + handle_error("Test server failure:\n") + self.close() + self.running = False + # normally, we'd just stop here, but for the test + # harness, we want to stop the server + self.server.stop() - def run(self): - self.sock.settimeout(0.05) - self.sock.listen() - self.active = True - if self.flag: - # signal an event - self.flag.set() - while self.active: - try: - newconn, connaddr = self.sock.accept() - if support.verbose and self.chatty: - sys.stdout.write(' server: new connection from ' - + repr(connaddr) + '\n') - handler = self.ConnectionHandler(self, newconn, connaddr) - handler.start() - handler.join() - except socket.timeout: - pass - except KeyboardInterrupt: - self.stop() - self.sock.close() + def __init__(self, certificate=None, ssl_version=None, + certreqs=None, cacerts=None, + chatty=True, connectionchatty=False, starttls_server=False, + npn_protocols=None, alpn_protocols=None, + ciphers=None, context=None): + if context: + self.context = context + else: + self.context = ssl.SSLContext(ssl_version + if ssl_version is not None + else ssl.PROTOCOL_TLSv1) + self.context.verify_mode = (certreqs if certreqs is not None + else ssl.CERT_NONE) + if cacerts: + self.context.load_verify_locations(cacerts) + if certificate: + self.context.load_cert_chain(certificate) + if npn_protocols: + self.context.set_npn_protocols(npn_protocols) + if alpn_protocols: + self.context.set_alpn_protocols(alpn_protocols) + if ciphers: + self.context.set_ciphers(ciphers) + self.chatty = chatty + self.connectionchatty = connectionchatty + self.starttls_server = starttls_server + self.sock = socket.socket() + self.port = support.bind_port(self.sock) + self.flag = None + self.active = False + self.selected_npn_protocols = [] + self.selected_alpn_protocols = [] + self.shared_ciphers = [] + self.conn_errors = [] + threading.Thread.__init__(self) + self.daemon = True + + def __enter__(self): + self.start(threading.Event()) + self.flag.wait() + return self + + def __exit__(self, *args): + self.stop() + self.join() + + def start(self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + self.sock.settimeout(0.05) + self.sock.listen() + self.active = True + if self.flag: + # signal an event + self.flag.set() + while self.active: + try: + newconn, connaddr = self.sock.accept() + if support.verbose and self.chatty: + sys.stdout.write(' server: new connection from ' + + repr(connaddr) + '\n') + handler = self.ConnectionHandler(self, newconn, connaddr) + handler.start() + handler.join() + except socket.timeout: + pass + except KeyboardInterrupt: + self.stop() + self.sock.close() - def stop(self): - self.active = False + def stop(self): + self.active = False - class AsyncoreEchoServer(threading.Thread): +class AsyncoreEchoServer(threading.Thread): - # this one's based on asyncore.dispatcher + # this one's based on asyncore.dispatcher - class EchoServer (asyncore.dispatcher): + class EchoServer (asyncore.dispatcher): - class ConnectionHandler(asyncore.dispatcher_with_send): + class ConnectionHandler(asyncore.dispatcher_with_send): - def __init__(self, conn, certfile): - self.socket = test_wrap_socket(conn, server_side=True, - certfile=certfile, - do_handshake_on_connect=False) - asyncore.dispatcher_with_send.__init__(self, self.socket) - self._ssl_accepting = True - self._do_ssl_handshake() + def __init__(self, conn, certfile): + self.socket = test_wrap_socket(conn, server_side=True, + certfile=certfile, + do_handshake_on_connect=False) + asyncore.dispatcher_with_send.__init__(self, self.socket) + self._ssl_accepting = True + self._do_ssl_handshake() - def readable(self): - if isinstance(self.socket, ssl.SSLSocket): - while self.socket.pending() > 0: - self.handle_read_event() - return True + def readable(self): + if isinstance(self.socket, ssl.SSLSocket): + while self.socket.pending() > 0: + self.handle_read_event() + return True - def _do_ssl_handshake(self): - try: - self.socket.do_handshake() - except (ssl.SSLWantReadError, ssl.SSLWantWriteError): - return - except ssl.SSLEOFError: + def _do_ssl_handshake(self): + try: + self.socket.do_handshake() + except (ssl.SSLWantReadError, ssl.SSLWantWriteError): + return + except ssl.SSLEOFError: + return self.handle_close() + except ssl.SSLError: + raise + except OSError as err: + if err.args[0] == errno.ECONNABORTED: return self.handle_close() - except ssl.SSLError: - raise - except OSError as err: - if err.args[0] == errno.ECONNABORTED: - return self.handle_close() - else: - self._ssl_accepting = False - - def handle_read(self): - if self._ssl_accepting: - self._do_ssl_handshake() - else: - data = self.recv(1024) - if support.verbose: - sys.stdout.write(" server: read %s from client\n" % repr(data)) - if not data: - self.close() - else: - self.send(data.lower()) + else: + self._ssl_accepting = False - def handle_close(self): - self.close() + def handle_read(self): + if self._ssl_accepting: + self._do_ssl_handshake() + else: + data = self.recv(1024) if support.verbose: - sys.stdout.write(" server: closed connection %s\n" % self.socket) - - def handle_error(self): - raise - - def __init__(self, certfile): - self.certfile = certfile - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = support.bind_port(sock, '') - asyncore.dispatcher.__init__(self, sock) - self.listen(5) + sys.stdout.write(" server: read %s from client\n" % repr(data)) + if not data: + self.close() + else: + self.send(data.lower()) - def handle_accepted(self, sock_obj, addr): + def handle_close(self): + self.close() if support.verbose: - sys.stdout.write(" server: new connection from %s:%s\n" %addr) - self.ConnectionHandler(sock_obj, self.certfile) + sys.stdout.write(" server: closed connection %s\n" % self.socket) def handle_error(self): raise def __init__(self, certfile): - self.flag = None - self.active = False - self.server = self.EchoServer(certfile) - self.port = self.server.port - threading.Thread.__init__(self) - self.daemon = True + self.certfile = certfile + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.port = support.bind_port(sock, '') + asyncore.dispatcher.__init__(self, sock) + self.listen(5) - def __str__(self): - return "<%s %s>" % (self.__class__.__name__, self.server) + def handle_accepted(self, sock_obj, addr): + if support.verbose: + sys.stdout.write(" server: new connection from %s:%s\n" %addr) + self.ConnectionHandler(sock_obj, self.certfile) - def __enter__(self): - self.start(threading.Event()) - self.flag.wait() - return self + def handle_error(self): + raise - def __exit__(self, *args): - if support.verbose: - sys.stdout.write(" cleanup: stopping server.\n") - self.stop() - if support.verbose: - sys.stdout.write(" cleanup: joining server thread.\n") - self.join() - if support.verbose: - sys.stdout.write(" cleanup: successfully joined.\n") - # make sure that ConnectionHandler is removed from socket_map - asyncore.close_all(ignore_all=True) + def __init__(self, certfile): + self.flag = None + self.active = False + self.server = self.EchoServer(certfile) + self.port = self.server.port + threading.Thread.__init__(self) + self.daemon = True - def start (self, flag=None): - self.flag = flag - threading.Thread.start(self) + def __str__(self): + return "<%s %s>" % (self.__class__.__name__, self.server) - def run(self): - self.active = True - if self.flag: - self.flag.set() - while self.active: - try: - asyncore.loop(1) - except: - pass + def __enter__(self): + self.start(threading.Event()) + self.flag.wait() + return self - def stop(self): - self.active = False - self.server.close() + def __exit__(self, *args): + if support.verbose: + sys.stdout.write(" cleanup: stopping server.\n") + self.stop() + if support.verbose: + sys.stdout.write(" cleanup: joining server thread.\n") + self.join() + if support.verbose: + sys.stdout.write(" cleanup: successfully joined.\n") + # make sure that ConnectionHandler is removed from socket_map + asyncore.close_all(ignore_all=True) + + def start (self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + self.active = True + if self.flag: + self.flag.set() + while self.active: + try: + asyncore.loop(1) + except: + pass - def server_params_test(client_context, server_context, indata=b"FOO\n", - chatty=True, connectionchatty=False, sni_name=None, - session=None): - """ - Launch a server, connect a client to it and try various reads - and writes. - """ - stats = {} - server = ThreadedEchoServer(context=server_context, - chatty=chatty, - connectionchatty=False) - with server: - with client_context.wrap_socket(socket.socket(), - server_hostname=sni_name, session=session) as s: - s.connect((HOST, server.port)) - for arg in [indata, bytearray(indata), memoryview(indata)]: - if connectionchatty: - if support.verbose: - sys.stdout.write( - " client: sending %r...\n" % indata) - s.write(arg) - outdata = s.read() - if connectionchatty: - if support.verbose: - sys.stdout.write(" client: read %r\n" % outdata) - if outdata != indata.lower(): - raise AssertionError( - "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n" - % (outdata[:20], len(outdata), - indata[:20].lower(), len(indata))) - s.write(b"over\n") + def stop(self): + self.active = False + self.server.close() + +def server_params_test(client_context, server_context, indata=b"FOO\n", + chatty=True, connectionchatty=False, sni_name=None, + session=None): + """ + Launch a server, connect a client to it and try various reads + and writes. + """ + stats = {} + server = ThreadedEchoServer(context=server_context, + chatty=chatty, + connectionchatty=False) + with server: + with client_context.wrap_socket(socket.socket(), + server_hostname=sni_name, session=session) as s: + s.connect((HOST, server.port)) + for arg in [indata, bytearray(indata), memoryview(indata)]: if connectionchatty: if support.verbose: - sys.stdout.write(" client: closing connection.\n") - stats.update({ - 'compression': s.compression(), - 'cipher': s.cipher(), - 'peercert': s.getpeercert(), - 'client_alpn_protocol': s.selected_alpn_protocol(), - 'client_npn_protocol': s.selected_npn_protocol(), - 'version': s.version(), - 'session_reused': s.session_reused, - 'session': s.session, - }) - s.close() - stats['server_alpn_protocols'] = server.selected_alpn_protocols - stats['server_npn_protocols'] = server.selected_npn_protocols - stats['server_shared_ciphers'] = server.shared_ciphers - return stats + sys.stdout.write( + " client: sending %r...\n" % indata) + s.write(arg) + outdata = s.read() + if connectionchatty: + if support.verbose: + sys.stdout.write(" client: read %r\n" % outdata) + if outdata != indata.lower(): + raise AssertionError( + "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n" + % (outdata[:20], len(outdata), + indata[:20].lower(), len(indata))) + s.write(b"over\n") + if connectionchatty: + if support.verbose: + sys.stdout.write(" client: closing connection.\n") + stats.update({ + 'compression': s.compression(), + 'cipher': s.cipher(), + 'peercert': s.getpeercert(), + 'client_alpn_protocol': s.selected_alpn_protocol(), + 'client_npn_protocol': s.selected_npn_protocol(), + 'version': s.version(), + 'session_reused': s.session_reused, + 'session': s.session, + }) + s.close() + stats['server_alpn_protocols'] = server.selected_alpn_protocols + stats['server_npn_protocols'] = server.selected_npn_protocols + stats['server_shared_ciphers'] = server.shared_ciphers + return stats + +def try_protocol_combo(server_protocol, client_protocol, expect_success, + certsreqs=None, server_options=0, client_options=0): + """ + Try to SSL-connect using *client_protocol* to *server_protocol*. + If *expect_success* is true, assert that the connection succeeds, + if it's false, assert that the connection fails. + Also, if *expect_success* is a string, assert that it is the protocol + version actually used by the connection. + """ + if certsreqs is None: + certsreqs = ssl.CERT_NONE + certtype = { + ssl.CERT_NONE: "CERT_NONE", + ssl.CERT_OPTIONAL: "CERT_OPTIONAL", + ssl.CERT_REQUIRED: "CERT_REQUIRED", + }[certsreqs] + if support.verbose: + formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n" + sys.stdout.write(formatstr % + (ssl.get_protocol_name(client_protocol), + ssl.get_protocol_name(server_protocol), + certtype)) + client_context = ssl.SSLContext(client_protocol) + client_context.options |= client_options + server_context = ssl.SSLContext(server_protocol) + server_context.options |= server_options + + # NOTE: we must enable "ALL" ciphers on the client, otherwise an + # SSLv23 client will send an SSLv3 hello (rather than SSLv2) + # starting from OpenSSL 1.0.0 (see issue #8322). + if client_context.protocol == ssl.PROTOCOL_SSLv23: + client_context.set_ciphers("ALL") + + for ctx in (client_context, server_context): + ctx.verify_mode = certsreqs + ctx.load_cert_chain(CERTFILE) + ctx.load_verify_locations(CERTFILE) + try: + stats = server_params_test(client_context, server_context, + chatty=False, connectionchatty=False) + # Protocol mismatch can result in either an SSLError, or a + # "Connection reset by peer" error. + except ssl.SSLError: + if expect_success: + raise + except OSError as e: + if expect_success or e.errno != errno.ECONNRESET: + raise + else: + if not expect_success: + raise AssertionError( + "Client protocol %s succeeded with server protocol %s!" + % (ssl.get_protocol_name(client_protocol), + ssl.get_protocol_name(server_protocol))) + elif (expect_success is not True + and expect_success != stats['version']): + raise AssertionError("version mismatch: expected %r, got %r" + % (expect_success, stats['version'])) - def try_protocol_combo(server_protocol, client_protocol, expect_success, - certsreqs=None, server_options=0, client_options=0): - """ - Try to SSL-connect using *client_protocol* to *server_protocol*. - If *expect_success* is true, assert that the connection succeeds, - if it's false, assert that the connection fails. - Also, if *expect_success* is a string, assert that it is the protocol - version actually used by the connection. - """ - if certsreqs is None: - certsreqs = ssl.CERT_NONE - certtype = { - ssl.CERT_NONE: "CERT_NONE", - ssl.CERT_OPTIONAL: "CERT_OPTIONAL", - ssl.CERT_REQUIRED: "CERT_REQUIRED", - }[certsreqs] - if support.verbose: - formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n" - sys.stdout.write(formatstr % - (ssl.get_protocol_name(client_protocol), - ssl.get_protocol_name(server_protocol), - certtype)) - client_context = ssl.SSLContext(client_protocol) - client_context.options |= client_options - server_context = ssl.SSLContext(server_protocol) - server_context.options |= server_options - - # NOTE: we must enable "ALL" ciphers on the client, otherwise an - # SSLv23 client will send an SSLv3 hello (rather than SSLv2) - # starting from OpenSSL 1.0.0 (see issue #8322). - if client_context.protocol == ssl.PROTOCOL_SSLv23: - client_context.set_ciphers("ALL") - - for ctx in (client_context, server_context): - ctx.verify_mode = certsreqs - ctx.load_cert_chain(CERTFILE) - ctx.load_verify_locations(CERTFILE) - try: - stats = server_params_test(client_context, server_context, - chatty=False, connectionchatty=False) - # Protocol mismatch can result in either an SSLError, or a - # "Connection reset by peer" error. - except ssl.SSLError: - if expect_success: - raise - except OSError as e: - if expect_success or e.errno != errno.ECONNRESET: - raise - else: - if not expect_success: - raise AssertionError( - "Client protocol %s succeeded with server protocol %s!" - % (ssl.get_protocol_name(client_protocol), - ssl.get_protocol_name(server_protocol))) - elif (expect_success is not True - and expect_success != stats['version']): - raise AssertionError("version mismatch: expected %r, got %r" - % (expect_success, stats['version'])) - - - class ThreadedTests(unittest.TestCase): - - @skip_if_broken_ubuntu_ssl - def test_echo(self): - """Basic test of an SSL client connecting to a server""" - if support.verbose: - sys.stdout.write("\n") - for protocol in PROTOCOLS: - if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}: - continue - with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]): - context = ssl.SSLContext(protocol) - context.load_cert_chain(CERTFILE) - server_params_test(context, context, - chatty=True, connectionchatty=True) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - client_context.load_verify_locations(SIGNING_CA) - server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - # server_context.load_verify_locations(SIGNING_CA) - server_context.load_cert_chain(SIGNED_CERTFILE2) +class ThreadedTests(unittest.TestCase): - with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER): - server_params_test(client_context=client_context, - server_context=server_context, + @skip_if_broken_ubuntu_ssl + def test_echo(self): + """Basic test of an SSL client connecting to a server""" + if support.verbose: + sys.stdout.write("\n") + for protocol in PROTOCOLS: + if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}: + continue + with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]): + context = ssl.SSLContext(protocol) + context.load_cert_chain(CERTFILE) + server_params_test(context, context, + chatty=True, connectionchatty=True) + + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + client_context.load_verify_locations(SIGNING_CA) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + # server_context.load_verify_locations(SIGNING_CA) + server_context.load_cert_chain(SIGNED_CERTFILE2) + + with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER): + server_params_test(client_context=client_context, + server_context=server_context, + chatty=True, connectionchatty=True, + sni_name='fakehostname') + + client_context.check_hostname = False + with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT): + with self.assertRaises(ssl.SSLError) as e: + server_params_test(client_context=server_context, + server_context=client_context, chatty=True, connectionchatty=True, sni_name='fakehostname') + self.assertIn('called a function you should not call', + str(e.exception)) - client_context.check_hostname = False - with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT): - with self.assertRaises(ssl.SSLError) as e: - server_params_test(client_context=server_context, - server_context=client_context, - chatty=True, connectionchatty=True, - sni_name='fakehostname') - self.assertIn('called a function you should not call', - str(e.exception)) - - with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER): - with self.assertRaises(ssl.SSLError) as e: - server_params_test(client_context=server_context, - server_context=server_context, - chatty=True, connectionchatty=True) - self.assertIn('called a function you should not call', - str(e.exception)) + with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER): + with self.assertRaises(ssl.SSLError) as e: + server_params_test(client_context=server_context, + server_context=server_context, + chatty=True, connectionchatty=True) + self.assertIn('called a function you should not call', + str(e.exception)) - with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT): - with self.assertRaises(ssl.SSLError) as e: - server_params_test(client_context=server_context, - server_context=client_context, - chatty=True, connectionchatty=True) - self.assertIn('called a function you should not call', - str(e.exception)) + with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT): + with self.assertRaises(ssl.SSLError) as e: + server_params_test(client_context=server_context, + server_context=client_context, + chatty=True, connectionchatty=True) + self.assertIn('called a function you should not call', + str(e.exception)) - def test_getpeercert(self): + def test_getpeercert(self): + if support.verbose: + sys.stdout.write("\n") + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + with server: + s = context.wrap_socket(socket.socket(), + do_handshake_on_connect=False) + s.connect((HOST, server.port)) + # getpeercert() raise ValueError while the handshake isn't + # done. + with self.assertRaises(ValueError): + s.getpeercert() + s.do_handshake() + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + cipher = s.cipher() if support.verbose: - sys.stdout.write("\n") - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - server = ThreadedEchoServer(context=context, chatty=False) - with server: - s = context.wrap_socket(socket.socket(), - do_handshake_on_connect=False) + sys.stdout.write(pprint.pformat(cert) + '\n') + sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') + if 'subject' not in cert: + self.fail("No subject field in certificate: %s." % + pprint.pformat(cert)) + if ((('organizationName', 'Python Software Foundation'),) + not in cert['subject']): + self.fail( + "Missing or invalid 'organizationName' field in certificate subject; " + "should be 'Python Software Foundation'.") + self.assertIn('notBefore', cert) + self.assertIn('notAfter', cert) + before = ssl.cert_time_to_seconds(cert['notBefore']) + after = ssl.cert_time_to_seconds(cert['notAfter']) + self.assertLess(before, after) + s.close() + + @unittest.skipUnless(have_verify_flags(), + "verify_flags need OpenSSL > 0.9.8") + def test_crl_check(self): + if support.verbose: + sys.stdout.write("\n") + + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(SIGNING_CA) + tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0) + self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf) + + # VERIFY_DEFAULT should pass + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket()) as s: s.connect((HOST, server.port)) - # getpeercert() raise ValueError while the handshake isn't - # done. - with self.assertRaises(ValueError): - s.getpeercert() - s.do_handshake() cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") - cipher = s.cipher() - if support.verbose: - sys.stdout.write(pprint.pformat(cert) + '\n') - sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') - if 'subject' not in cert: - self.fail("No subject field in certificate: %s." % - pprint.pformat(cert)) - if ((('organizationName', 'Python Software Foundation'),) - not in cert['subject']): - self.fail( - "Missing or invalid 'organizationName' field in certificate subject; " - "should be 'Python Software Foundation'.") - self.assertIn('notBefore', cert) - self.assertIn('notAfter', cert) - before = ssl.cert_time_to_seconds(cert['notBefore']) - after = ssl.cert_time_to_seconds(cert['notAfter']) - self.assertLess(before, after) - s.close() - @unittest.skipUnless(have_verify_flags(), - "verify_flags need OpenSSL > 0.9.8") - def test_crl_check(self): - if support.verbose: - sys.stdout.write("\n") + # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails + context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(SIGNING_CA) - tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0) - self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf) - - # VERIFY_DEFAULT should pass - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with context.wrap_socket(socket.socket()) as s: + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket()) as s: + with self.assertRaisesRegex(ssl.SSLError, + "certificate verify failed"): s.connect((HOST, server.port)) - cert = s.getpeercert() - self.assertTrue(cert, "Can't get peer certificate.") - # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails - context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF + # now load a CRL file. The CRL file is signed by the CA. + context.load_verify_locations(CRLFILE) - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with context.wrap_socket(socket.socket()) as s: - with self.assertRaisesRegex(ssl.SSLError, - "certificate verify failed"): - s.connect((HOST, server.port)) + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") - # now load a CRL file. The CRL file is signed by the CA. - context.load_verify_locations(CRLFILE) + def test_check_hostname(self): + if support.verbose: + sys.stdout.write("\n") - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with context.wrap_socket(socket.socket()) as s: - s.connect((HOST, server.port)) - cert = s.getpeercert() - self.assertTrue(cert, "Can't get peer certificate.") + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) - def test_check_hostname(self): - if support.verbose: - sys.stdout.write("\n") + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + context.load_verify_locations(SIGNING_CA) - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.verify_mode = ssl.CERT_REQUIRED - context.check_hostname = True - context.load_verify_locations(SIGNING_CA) - - # correct hostname should verify - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with context.wrap_socket(socket.socket(), - server_hostname="localhost") as s: + # correct hostname should verify + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket(), + server_hostname="localhost") as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + + # incorrect hostname should raise an exception + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket(), + server_hostname="invalid") as s: + with self.assertRaisesRegex(ssl.CertificateError, + "hostname 'invalid' doesn't match 'localhost'"): s.connect((HOST, server.port)) - cert = s.getpeercert() - self.assertTrue(cert, "Can't get peer certificate.") - - # incorrect hostname should raise an exception - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with context.wrap_socket(socket.socket(), - server_hostname="invalid") as s: - with self.assertRaisesRegex(ssl.CertificateError, - "hostname 'invalid' doesn't match 'localhost'"): - s.connect((HOST, server.port)) - - # missing server_hostname arg should cause an exception, too - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with socket.socket() as s: - with self.assertRaisesRegex(ValueError, - "check_hostname requires server_hostname"): - context.wrap_socket(s) - - def test_wrong_cert(self): - """Connecting when the server rejects the client's certificate - - Launch a server with CERT_REQUIRED, and check that trying to - connect to it with a wrong client certificate fails. - """ - certfile = os.path.join(os.path.dirname(__file__) or os.curdir, - "wrongcert.pem") - server = ThreadedEchoServer(CERTFILE, - certreqs=ssl.CERT_REQUIRED, - cacerts=CERTFILE, chatty=False, - connectionchatty=False) - with server, \ - socket.socket() as sock, \ - test_wrap_socket(sock, - certfile=certfile, - ssl_version=ssl.PROTOCOL_TLSv1) as s: + + # missing server_hostname arg should cause an exception, too + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with socket.socket() as s: + with self.assertRaisesRegex(ValueError, + "check_hostname requires server_hostname"): + context.wrap_socket(s) + + def test_wrong_cert(self): + """Connecting when the server rejects the client's certificate + + Launch a server with CERT_REQUIRED, and check that trying to + connect to it with a wrong client certificate fails. + """ + certfile = os.path.join(os.path.dirname(__file__) or os.curdir, + "wrongcert.pem") + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_REQUIRED, + cacerts=CERTFILE, chatty=False, + connectionchatty=False) + with server, \ + socket.socket() as sock, \ + test_wrap_socket(sock, + certfile=certfile, + ssl_version=ssl.PROTOCOL_TLSv1) as s: + try: + # Expect either an SSL error about the server rejecting + # the connection, or a low-level connection reset (which + # sometimes happens on Windows) + s.connect((HOST, server.port)) + except ssl.SSLError as e: + if support.verbose: + sys.stdout.write("\nSSLError is %r\n" % e) + except OSError as e: + if e.errno != errno.ECONNRESET: + raise + if support.verbose: + sys.stdout.write("\nsocket.error is %r\n" % e) + else: + self.fail("Use of invalid cert should have failed!") + + def test_rude_shutdown(self): + """A brutal shutdown of an SSL server should raise an OSError + in the client when attempting handshake. + """ + listener_ready = threading.Event() + listener_gone = threading.Event() + + s = socket.socket() + port = support.bind_port(s, HOST) + + # `listener` runs in a thread. It sits in an accept() until + # the main thread connects. Then it rudely closes the socket, + # and sets Event `listener_gone` to let the main thread know + # the socket is gone. + def listener(): + s.listen() + listener_ready.set() + newsock, addr = s.accept() + newsock.close() + s.close() + listener_gone.set() + + def connector(): + listener_ready.wait() + with socket.socket() as c: + c.connect((HOST, port)) + listener_gone.wait() try: - # Expect either an SSL error about the server rejecting - # the connection, or a low-level connection reset (which - # sometimes happens on Windows) - s.connect((HOST, server.port)) - except ssl.SSLError as e: - if support.verbose: - sys.stdout.write("\nSSLError is %r\n" % e) - except OSError as e: - if e.errno != errno.ECONNRESET: - raise - if support.verbose: - sys.stdout.write("\nsocket.error is %r\n" % e) + ssl_sock = test_wrap_socket(c) + except OSError: + pass else: - self.fail("Use of invalid cert should have failed!") + self.fail('connecting to closed SSL socket should have failed') - def test_rude_shutdown(self): - """A brutal shutdown of an SSL server should raise an OSError - in the client when attempting handshake. - """ - listener_ready = threading.Event() - listener_gone = threading.Event() + t = threading.Thread(target=listener) + t.start() + try: + connector() + finally: + t.join() - s = socket.socket() - port = support.bind_port(s, HOST) - - # `listener` runs in a thread. It sits in an accept() until - # the main thread connects. Then it rudely closes the socket, - # and sets Event `listener_gone` to let the main thread know - # the socket is gone. - def listener(): - s.listen() - listener_ready.set() - newsock, addr = s.accept() - newsock.close() - s.close() - listener_gone.set() - - def connector(): - listener_ready.wait() - with socket.socket() as c: - c.connect((HOST, port)) - listener_gone.wait() - try: - ssl_sock = test_wrap_socket(c) - except OSError: - pass - else: - self.fail('connecting to closed SSL socket should have failed') + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'), + "OpenSSL is compiled without SSLv2 support") + def test_protocol_sslv2(self): + """Connecting to an SSLv2 server with various client options""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False) + # SSLv23 client with specific SSL options + if no_sslv2_implies_sslv3_hello(): + # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_SSLv2) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_SSLv3) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1) - t = threading.Thread(target=listener) - t.start() + @skip_if_broken_ubuntu_ssl + def test_protocol_sslv23(self): + """Connecting to an SSLv23 server with various client options""" + if support.verbose: + sys.stdout.write("\n") + if hasattr(ssl, 'PROTOCOL_SSLv2'): try: - connector() - finally: - t.join() + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) + except OSError as x: + # this fails on some older versions of OpenSSL (0.9.7l, for instance) + if support.verbose: + sys.stdout.write( + " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" + % str(x)) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1') + + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) + + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) + + # Server with specific SSL options + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, + server_options=ssl.OP_NO_SSLv3) + # Will choose TLSv1 + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, + server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False, + server_options=ssl.OP_NO_TLSv1) + + + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv3'), + "OpenSSL is compiled without SSLv3 support") + def test_protocol_sslv3(self): + """Connecting to an SSLv3 server with various client options""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3') + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_SSLv3) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) + if no_sslv2_implies_sslv3_hello(): + # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, + False, client_options=ssl.OP_NO_SSLv2) + + @skip_if_broken_ubuntu_ssl + def test_protocol_tlsv1(self): + """Connecting to a TLSv1 server with various client options""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1') + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1) + + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"), + "TLS version 1.1 not supported.") + def test_protocol_tlsv1_1(self): + """Connecting to a TLSv1.1 server with various client options. + Testing against older TLS versions.""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1_1) + + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) - @skip_if_broken_ubuntu_ssl - @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'), - "OpenSSL is compiled without SSLv2 support") - def test_protocol_sslv2(self): - """Connecting to an SSLv2 server with various client options""" - if support.verbose: - sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False) - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False) - # SSLv23 client with specific SSL options - if no_sslv2_implies_sslv3_hello(): - # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_SSLv2) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_SSLv3) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_TLSv1) - @skip_if_broken_ubuntu_ssl - def test_protocol_sslv23(self): - """Connecting to an SSLv23 server with various client options""" + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"), + "TLS version 1.2 not supported.") + def test_protocol_tlsv1_2(self): + """Connecting to a TLSv1.2 server with various client options. + Testing against older TLS versions.""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2', + server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2, + client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1_2) + + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False) + + def test_starttls(self): + """Switching from clear text to encrypted and back again.""" + msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6") + + server = ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_TLSv1, + starttls_server=True, + chatty=True, + connectionchatty=True) + wrapped = False + with server: + s = socket.socket() + s.setblocking(1) + s.connect((HOST, server.port)) if support.verbose: sys.stdout.write("\n") - if hasattr(ssl, 'PROTOCOL_SSLv2'): - try: - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) - except OSError as x: - # this fails on some older versions of OpenSSL (0.9.7l, for instance) + for indata in msgs: + if support.verbose: + sys.stdout.write( + " client: sending %r...\n" % indata) + if wrapped: + conn.write(indata) + outdata = conn.read() + else: + s.send(indata) + outdata = s.recv(1024) + msg = outdata.strip().lower() + if indata == b"STARTTLS" and msg.startswith(b"ok"): + # STARTTLS ok, switch to secure mode if support.verbose: sys.stdout.write( - " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" - % str(x)) - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1') - - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) - - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) - - # Server with specific SSL options - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, - server_options=ssl.OP_NO_SSLv3) - # Will choose TLSv1 - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, - server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False, - server_options=ssl.OP_NO_TLSv1) - - - @skip_if_broken_ubuntu_ssl - @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv3'), - "OpenSSL is compiled without SSLv3 support") - def test_protocol_sslv3(self): - """Connecting to an SSLv3 server with various client options""" - if support.verbose: - sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3') - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) - if hasattr(ssl, 'PROTOCOL_SSLv2'): - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_SSLv3) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) - if no_sslv2_implies_sslv3_hello(): - # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, - False, client_options=ssl.OP_NO_SSLv2) - - @skip_if_broken_ubuntu_ssl - def test_protocol_tlsv1(self): - """Connecting to a TLSv1 server with various client options""" - if support.verbose: - sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1') - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) - if hasattr(ssl, 'PROTOCOL_SSLv2'): - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_TLSv1) - - @skip_if_broken_ubuntu_ssl - @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"), - "TLS version 1.1 not supported.") - def test_protocol_tlsv1_1(self): - """Connecting to a TLSv1.1 server with various client options. - Testing against older TLS versions.""" - if support.verbose: - sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') - if hasattr(ssl, 'PROTOCOL_SSLv2'): - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_TLSv1_1) - - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) - - - @skip_if_broken_ubuntu_ssl - @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"), - "TLS version 1.2 not supported.") - def test_protocol_tlsv1_2(self): - """Connecting to a TLSv1.2 server with various client options. - Testing against older TLS versions.""" - if support.verbose: - sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2', - server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2, - client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,) - if hasattr(ssl, 'PROTOCOL_SSLv2'): - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False) - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_TLSv1_2) - - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False) - - def test_starttls(self): - """Switching from clear text to encrypted and back again.""" - msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6") - - server = ThreadedEchoServer(CERTFILE, - ssl_version=ssl.PROTOCOL_TLSv1, - starttls_server=True, - chatty=True, - connectionchatty=True) - wrapped = False - with server: - s = socket.socket() - s.setblocking(1) - s.connect((HOST, server.port)) - if support.verbose: - sys.stdout.write("\n") - for indata in msgs: + " client: read %r from server, starting TLS...\n" + % msg) + conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) + wrapped = True + elif indata == b"ENDTLS" and msg.startswith(b"ok"): + # ENDTLS ok, switch back to clear text if support.verbose: sys.stdout.write( - " client: sending %r...\n" % indata) - if wrapped: - conn.write(indata) - outdata = conn.read() - else: - s.send(indata) - outdata = s.recv(1024) - msg = outdata.strip().lower() - if indata == b"STARTTLS" and msg.startswith(b"ok"): - # STARTTLS ok, switch to secure mode - if support.verbose: - sys.stdout.write( - " client: read %r from server, starting TLS...\n" - % msg) - conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) - wrapped = True - elif indata == b"ENDTLS" and msg.startswith(b"ok"): - # ENDTLS ok, switch back to clear text - if support.verbose: - sys.stdout.write( - " client: read %r from server, ending TLS...\n" - % msg) - s = conn.unwrap() - wrapped = False - else: - if support.verbose: - sys.stdout.write( - " client: read %r from server\n" % msg) - if support.verbose: - sys.stdout.write(" client: closing connection.\n") - if wrapped: - conn.write(b"over\n") + " client: read %r from server, ending TLS...\n" + % msg) + s = conn.unwrap() + wrapped = False else: - s.send(b"over\n") - if wrapped: - conn.close() - else: - s.close() - - def test_socketserver(self): - """Using socketserver to create and manage SSL connections.""" - server = make_https_server(self, certfile=CERTFILE) - # try to connect - if support.verbose: - sys.stdout.write('\n') - with open(CERTFILE, 'rb') as f: - d1 = f.read() - d2 = '' - # now fetch the same data from the HTTPS server - url = 'https://localhost:%d/%s' % ( - server.port, os.path.split(CERTFILE)[1]) - context = ssl.create_default_context(cafile=CERTFILE) - f = urllib.request.urlopen(url, context=context) - try: - dlen = f.info().get("content-length") - if dlen and (int(dlen) > 0): - d2 = f.read(int(dlen)) if support.verbose: sys.stdout.write( - " client: read %d bytes from remote server '%s'\n" - % (len(d2), server)) - finally: - f.close() - self.assertEqual(d1, d2) - - def test_asyncore_server(self): - """Check the example asyncore integration.""" + " client: read %r from server\n" % msg) if support.verbose: - sys.stdout.write("\n") + sys.stdout.write(" client: closing connection.\n") + if wrapped: + conn.write(b"over\n") + else: + s.send(b"over\n") + if wrapped: + conn.close() + else: + s.close() - indata = b"FOO\n" - server = AsyncoreEchoServer(CERTFILE) - with server: - s = test_wrap_socket(socket.socket()) - s.connect(('127.0.0.1', server.port)) + def test_socketserver(self): + """Using socketserver to create and manage SSL connections.""" + server = make_https_server(self, certfile=CERTFILE) + # try to connect + if support.verbose: + sys.stdout.write('\n') + with open(CERTFILE, 'rb') as f: + d1 = f.read() + d2 = '' + # now fetch the same data from the HTTPS server + url = 'https://localhost:%d/%s' % ( + server.port, os.path.split(CERTFILE)[1]) + context = ssl.create_default_context(cafile=CERTFILE) + f = urllib.request.urlopen(url, context=context) + try: + dlen = f.info().get("content-length") + if dlen and (int(dlen) > 0): + d2 = f.read(int(dlen)) if support.verbose: sys.stdout.write( - " client: sending %r...\n" % indata) - s.write(indata) - outdata = s.read() - if support.verbose: - sys.stdout.write(" client: read %r\n" % outdata) - if outdata != indata.lower(): - self.fail( - "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n" - % (outdata[:20], len(outdata), - indata[:20].lower(), len(indata))) - s.write(b"over\n") - if support.verbose: - sys.stdout.write(" client: closing connection.\n") - s.close() - if support.verbose: - sys.stdout.write(" client: connection closed.\n") + " client: read %d bytes from remote server '%s'\n" + % (len(d2), server)) + finally: + f.close() + self.assertEqual(d1, d2) - def test_recv_send(self): - """Test recv(), send() and friends.""" + def test_asyncore_server(self): + """Check the example asyncore integration.""" + if support.verbose: + sys.stdout.write("\n") + + indata = b"FOO\n" + server = AsyncoreEchoServer(CERTFILE) + with server: + s = test_wrap_socket(socket.socket()) + s.connect(('127.0.0.1', server.port)) if support.verbose: - sys.stdout.write("\n") + sys.stdout.write( + " client: sending %r...\n" % indata) + s.write(indata) + outdata = s.read() + if support.verbose: + sys.stdout.write(" client: read %r\n" % outdata) + if outdata != indata.lower(): + self.fail( + "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n" + % (outdata[:20], len(outdata), + indata[:20].lower(), len(indata))) + s.write(b"over\n") + if support.verbose: + sys.stdout.write(" client: closing connection.\n") + s.close() + if support.verbose: + sys.stdout.write(" client: connection closed.\n") - server = ThreadedEchoServer(CERTFILE, - certreqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1, - cacerts=CERTFILE, - chatty=True, - connectionchatty=False) - with server: - s = test_wrap_socket(socket.socket(), - server_side=False, - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) - s.connect((HOST, server.port)) - # helper methods for standardising recv* method signatures - def _recv_into(): - b = bytearray(b"\0"*100) - count = s.recv_into(b) - return b[:count] - - def _recvfrom_into(): - b = bytearray(b"\0"*100) - count, addr = s.recvfrom_into(b) - return b[:count] - - # (name, method, expect success?, *args, return value func) - send_methods = [ - ('send', s.send, True, [], len), - ('sendto', s.sendto, False, ["some.address"], len), - ('sendall', s.sendall, True, [], lambda x: None), - ] - # (name, method, whether to expect success, *args) - recv_methods = [ - ('recv', s.recv, True, []), - ('recvfrom', s.recvfrom, False, ["some.address"]), - ('recv_into', _recv_into, True, []), - ('recvfrom_into', _recvfrom_into, False, []), - ] - data_prefix = "PREFIX_" - - for (meth_name, send_meth, expect_success, args, - ret_val_meth) in send_methods: - indata = (data_prefix + meth_name).encode('ascii') - try: - ret = send_meth(indata, *args) - msg = "sending with {}".format(meth_name) - self.assertEqual(ret, ret_val_meth(indata), msg=msg) - outdata = s.read() - if outdata != indata.lower(): - self.fail( - "While sending with <<{name:s}>> bad data " - "<<{outdata:r}>> ({nout:d}) received; " - "expected <<{indata:r}>> ({nin:d})\n".format( - name=meth_name, outdata=outdata[:20], - nout=len(outdata), - indata=indata[:20], nin=len(indata) - ) - ) - except ValueError as e: - if expect_success: - self.fail( - "Failed to send with method <<{name:s}>>; " - "expected to succeed.\n".format(name=meth_name) + def test_recv_send(self): + """Test recv(), send() and friends.""" + if support.verbose: + sys.stdout.write("\n") + + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + with server: + s = test_wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + # helper methods for standardising recv* method signatures + def _recv_into(): + b = bytearray(b"\0"*100) + count = s.recv_into(b) + return b[:count] + + def _recvfrom_into(): + b = bytearray(b"\0"*100) + count, addr = s.recvfrom_into(b) + return b[:count] + + # (name, method, expect success?, *args, return value func) + send_methods = [ + ('send', s.send, True, [], len), + ('sendto', s.sendto, False, ["some.address"], len), + ('sendall', s.sendall, True, [], lambda x: None), + ] + # (name, method, whether to expect success, *args) + recv_methods = [ + ('recv', s.recv, True, []), + ('recvfrom', s.recvfrom, False, ["some.address"]), + ('recv_into', _recv_into, True, []), + ('recvfrom_into', _recvfrom_into, False, []), + ] + data_prefix = "PREFIX_" + + for (meth_name, send_meth, expect_success, args, + ret_val_meth) in send_methods: + indata = (data_prefix + meth_name).encode('ascii') + try: + ret = send_meth(indata, *args) + msg = "sending with {}".format(meth_name) + self.assertEqual(ret, ret_val_meth(indata), msg=msg) + outdata = s.read() + if outdata != indata.lower(): + self.fail( + "While sending with <<{name:s}>> bad data " + "<<{outdata:r}>> ({nout:d}) received; " + "expected <<{indata:r}>> ({nin:d})\n".format( + name=meth_name, outdata=outdata[:20], + nout=len(outdata), + indata=indata[:20], nin=len(indata) ) - if not str(e).startswith(meth_name): - self.fail( - "Method <<{name:s}>> failed with unexpected " - "exception message: {exp:s}\n".format( - name=meth_name, exp=e - ) + ) + except ValueError as e: + if expect_success: + self.fail( + "Failed to send with method <<{name:s}>>; " + "expected to succeed.\n".format(name=meth_name) + ) + if not str(e).startswith(meth_name): + self.fail( + "Method <<{name:s}>> failed with unexpected " + "exception message: {exp:s}\n".format( + name=meth_name, exp=e ) + ) - for meth_name, recv_meth, expect_success, args in recv_methods: - indata = (data_prefix + meth_name).encode('ascii') - try: - s.send(indata) - outdata = recv_meth(*args) - if outdata != indata.lower(): - self.fail( - "While receiving with <<{name:s}>> bad data " - "<<{outdata:r}>> ({nout:d}) received; " - "expected <<{indata:r}>> ({nin:d})\n".format( - name=meth_name, outdata=outdata[:20], - nout=len(outdata), - indata=indata[:20], nin=len(indata) - ) - ) - except ValueError as e: - if expect_success: - self.fail( - "Failed to receive with method <<{name:s}>>; " - "expected to succeed.\n".format(name=meth_name) + for meth_name, recv_meth, expect_success, args in recv_methods: + indata = (data_prefix + meth_name).encode('ascii') + try: + s.send(indata) + outdata = recv_meth(*args) + if outdata != indata.lower(): + self.fail( + "While receiving with <<{name:s}>> bad data " + "<<{outdata:r}>> ({nout:d}) received; " + "expected <<{indata:r}>> ({nin:d})\n".format( + name=meth_name, outdata=outdata[:20], + nout=len(outdata), + indata=indata[:20], nin=len(indata) ) - if not str(e).startswith(meth_name): - self.fail( - "Method <<{name:s}>> failed with unexpected " - "exception message: {exp:s}\n".format( - name=meth_name, exp=e - ) + ) + except ValueError as e: + if expect_success: + self.fail( + "Failed to receive with method <<{name:s}>>; " + "expected to succeed.\n".format(name=meth_name) + ) + if not str(e).startswith(meth_name): + self.fail( + "Method <<{name:s}>> failed with unexpected " + "exception message: {exp:s}\n".format( + name=meth_name, exp=e ) - # consume data - s.read() + ) + # consume data + s.read() - # read(-1, buffer) is supported, even though read(-1) is not - data = b"data" - s.send(data) - buffer = bytearray(len(data)) - self.assertEqual(s.read(-1, buffer), len(data)) - self.assertEqual(buffer, data) + # read(-1, buffer) is supported, even though read(-1) is not + data = b"data" + s.send(data) + buffer = bytearray(len(data)) + self.assertEqual(s.read(-1, buffer), len(data)) + self.assertEqual(buffer, data) - # Make sure sendmsg et al are disallowed to avoid - # inadvertent disclosure of data and/or corruption - # of the encrypted data stream - self.assertRaises(NotImplementedError, s.sendmsg, [b"data"]) - self.assertRaises(NotImplementedError, s.recvmsg, 100) - self.assertRaises(NotImplementedError, - s.recvmsg_into, bytearray(100)) + # Make sure sendmsg et al are disallowed to avoid + # inadvertent disclosure of data and/or corruption + # of the encrypted data stream + self.assertRaises(NotImplementedError, s.sendmsg, [b"data"]) + self.assertRaises(NotImplementedError, s.recvmsg, 100) + self.assertRaises(NotImplementedError, + s.recvmsg_into, bytearray(100)) - s.write(b"over\n") + s.write(b"over\n") - self.assertRaises(ValueError, s.recv, -1) - self.assertRaises(ValueError, s.read, -1) + self.assertRaises(ValueError, s.recv, -1) + self.assertRaises(ValueError, s.read, -1) - s.close() + s.close() - def test_recv_zero(self): - server = ThreadedEchoServer(CERTFILE) - server.__enter__() - self.addCleanup(server.__exit__, None, None) - s = socket.create_connection((HOST, server.port)) - self.addCleanup(s.close) - s = test_wrap_socket(s, suppress_ragged_eofs=False) - self.addCleanup(s.close) + def test_recv_zero(self): + server = ThreadedEchoServer(CERTFILE) + server.__enter__() + self.addCleanup(server.__exit__, None, None) + s = socket.create_connection((HOST, server.port)) + self.addCleanup(s.close) + s = test_wrap_socket(s, suppress_ragged_eofs=False) + self.addCleanup(s.close) - # recv/read(0) should return no data - s.send(b"data") - self.assertEqual(s.recv(0), b"") - self.assertEqual(s.read(0), b"") - self.assertEqual(s.read(), b"data") + # recv/read(0) should return no data + s.send(b"data") + self.assertEqual(s.recv(0), b"") + self.assertEqual(s.read(0), b"") + self.assertEqual(s.read(), b"data") + + # Should not block if the other end sends no data + s.setblocking(False) + self.assertEqual(s.recv(0), b"") + self.assertEqual(s.recv_into(bytearray()), 0) - # Should not block if the other end sends no data + def test_nonblocking_send(self): + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + with server: + s = test_wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) s.setblocking(False) - self.assertEqual(s.recv(0), b"") - self.assertEqual(s.recv_into(bytearray()), 0) - - def test_nonblocking_send(self): - server = ThreadedEchoServer(CERTFILE, - certreqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1, - cacerts=CERTFILE, - chatty=True, - connectionchatty=False) - with server: - s = test_wrap_socket(socket.socket(), - server_side=False, - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) - s.connect((HOST, server.port)) - s.setblocking(False) - - # If we keep sending data, at some point the buffers - # will be full and the call will block - buf = bytearray(8192) - def fill_buffer(): - while True: - s.send(buf) - self.assertRaises((ssl.SSLWantWriteError, - ssl.SSLWantReadError), fill_buffer) - - # Now read all the output and discard it - s.setblocking(True) - s.close() - def test_handshake_timeout(self): - # Issue #5103: SSL handshake must respect the socket timeout - server = socket.socket(socket.AF_INET) - host = "127.0.0.1" - port = support.bind_port(server) - started = threading.Event() - finish = False - - def serve(): - server.listen() - started.set() - conns = [] - while not finish: - r, w, e = select.select([server], [], [], 0.1) - if server in r: - # Let the socket hang around rather than having - # it closed by garbage collection. - conns.append(server.accept()[0]) - for sock in conns: - sock.close() - - t = threading.Thread(target=serve) - t.start() - started.wait() + # If we keep sending data, at some point the buffers + # will be full and the call will block + buf = bytearray(8192) + def fill_buffer(): + while True: + s.send(buf) + self.assertRaises((ssl.SSLWantWriteError, + ssl.SSLWantReadError), fill_buffer) + + # Now read all the output and discard it + s.setblocking(True) + s.close() + + def test_handshake_timeout(self): + # Issue #5103: SSL handshake must respect the socket timeout + server = socket.socket(socket.AF_INET) + host = "127.0.0.1" + port = support.bind_port(server) + started = threading.Event() + finish = False + + def serve(): + server.listen() + started.set() + conns = [] + while not finish: + r, w, e = select.select([server], [], [], 0.1) + if server in r: + # Let the socket hang around rather than having + # it closed by garbage collection. + conns.append(server.accept()[0]) + for sock in conns: + sock.close() + + t = threading.Thread(target=serve) + t.start() + started.wait() + try: try: - try: - c = socket.socket(socket.AF_INET) - c.settimeout(0.2) - c.connect((host, port)) - # Will attempt handshake and time out - self.assertRaisesRegex(socket.timeout, "timed out", - test_wrap_socket, c) - finally: - c.close() - try: - c = socket.socket(socket.AF_INET) - c = test_wrap_socket(c) - c.settimeout(0.2) - # Will attempt handshake and time out - self.assertRaisesRegex(socket.timeout, "timed out", - c.connect, (host, port)) - finally: - c.close() + c = socket.socket(socket.AF_INET) + c.settimeout(0.2) + c.connect((host, port)) + # Will attempt handshake and time out + self.assertRaisesRegex(socket.timeout, "timed out", + test_wrap_socket, c) finally: - finish = True - t.join() - server.close() - - def test_server_accept(self): - # Issue #16357: accept() on a SSLSocket created through - # SSLContext.wrap_socket(). - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - server = socket.socket(socket.AF_INET) - host = "127.0.0.1" - port = support.bind_port(server) - server = context.wrap_socket(server, server_side=True) - self.assertTrue(server.server_side) - - evt = threading.Event() - remote = None - peer = None - def serve(): - nonlocal remote, peer - server.listen() - # Block on the accept and wait on the connection to close. - evt.set() - remote, peer = server.accept() - remote.recv(1) - - t = threading.Thread(target=serve) - t.start() - # Client wait until server setup and perform a connect. - evt.wait() - client = context.wrap_socket(socket.socket()) - client.connect((host, port)) - client_addr = client.getsockname() - client.close() + c.close() + try: + c = socket.socket(socket.AF_INET) + c = test_wrap_socket(c) + c.settimeout(0.2) + # Will attempt handshake and time out + self.assertRaisesRegex(socket.timeout, "timed out", + c.connect, (host, port)) + finally: + c.close() + finally: + finish = True t.join() - remote.close() server.close() - # Sanity checks. - self.assertIsInstance(remote, ssl.SSLSocket) - self.assertEqual(peer, client_addr) - - def test_getpeercert_enotconn(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - with context.wrap_socket(socket.socket()) as sock: - with self.assertRaises(OSError) as cm: - sock.getpeercert() - self.assertEqual(cm.exception.errno, errno.ENOTCONN) - - def test_do_handshake_enotconn(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - with context.wrap_socket(socket.socket()) as sock: - with self.assertRaises(OSError) as cm: - sock.do_handshake() - self.assertEqual(cm.exception.errno, errno.ENOTCONN) - - def test_default_ciphers(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - try: - # Force a set of weak ciphers on our client context - context.set_ciphers("DES") - except ssl.SSLError: - self.skipTest("no DES cipher available") - with ThreadedEchoServer(CERTFILE, - ssl_version=ssl.PROTOCOL_SSLv23, - chatty=False) as server: - with context.wrap_socket(socket.socket()) as s: - with self.assertRaises(OSError): - s.connect((HOST, server.port)) - self.assertIn("no shared cipher", server.conn_errors[0]) - - def test_version_basic(self): - """ - Basic tests for SSLSocket.version(). - More tests are done in the test_protocol_*() methods. - """ - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - with ThreadedEchoServer(CERTFILE, - ssl_version=ssl.PROTOCOL_TLSv1, - chatty=False) as server: - with context.wrap_socket(socket.socket()) as s: - self.assertIs(s.version(), None) - s.connect((HOST, server.port)) - self.assertEqual(s.version(), 'TLSv1') - self.assertIs(s.version(), None) - @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL") - def test_default_ecdh_curve(self): - # Issue #21015: elliptic curve-based Diffie Hellman key exchange - # should be enabled by default on SSL contexts. - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.load_cert_chain(CERTFILE) - # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled - # explicitly using the 'ECCdraft' cipher alias. Otherwise, - # our default cipher list should prefer ECDH-based ciphers - # automatically. - if ssl.OPENSSL_VERSION_INFO < (1, 0, 0): - context.set_ciphers("ECCdraft:ECDH") - with ThreadedEchoServer(context=context) as server: - with context.wrap_socket(socket.socket()) as s: + def test_server_accept(self): + # Issue #16357: accept() on a SSLSocket created through + # SSLContext.wrap_socket(). + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = socket.socket(socket.AF_INET) + host = "127.0.0.1" + port = support.bind_port(server) + server = context.wrap_socket(server, server_side=True) + self.assertTrue(server.server_side) + + evt = threading.Event() + remote = None + peer = None + def serve(): + nonlocal remote, peer + server.listen() + # Block on the accept and wait on the connection to close. + evt.set() + remote, peer = server.accept() + remote.recv(1) + + t = threading.Thread(target=serve) + t.start() + # Client wait until server setup and perform a connect. + evt.wait() + client = context.wrap_socket(socket.socket()) + client.connect((host, port)) + client_addr = client.getsockname() + client.close() + t.join() + remote.close() + server.close() + # Sanity checks. + self.assertIsInstance(remote, ssl.SSLSocket) + self.assertEqual(peer, client_addr) + + def test_getpeercert_enotconn(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + with context.wrap_socket(socket.socket()) as sock: + with self.assertRaises(OSError) as cm: + sock.getpeercert() + self.assertEqual(cm.exception.errno, errno.ENOTCONN) + + def test_do_handshake_enotconn(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + with context.wrap_socket(socket.socket()) as sock: + with self.assertRaises(OSError) as cm: + sock.do_handshake() + self.assertEqual(cm.exception.errno, errno.ENOTCONN) + + def test_default_ciphers(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + try: + # Force a set of weak ciphers on our client context + context.set_ciphers("DES") + except ssl.SSLError: + self.skipTest("no DES cipher available") + with ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_SSLv23, + chatty=False) as server: + with context.wrap_socket(socket.socket()) as s: + with self.assertRaises(OSError): s.connect((HOST, server.port)) - self.assertIn("ECDH", s.cipher()[0]) + self.assertIn("no shared cipher", server.conn_errors[0]) - @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, - "'tls-unique' channel binding not available") - def test_tls_unique_channel_binding(self): - """Test tls-unique channel binding.""" - if support.verbose: - sys.stdout.write("\n") - - server = ThreadedEchoServer(CERTFILE, - certreqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1, - cacerts=CERTFILE, - chatty=True, - connectionchatty=False) - with server: - s = test_wrap_socket(socket.socket(), - server_side=False, - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) + def test_version_basic(self): + """ + Basic tests for SSLSocket.version(). + More tests are done in the test_protocol_*() methods. + """ + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + with ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_TLSv1, + chatty=False) as server: + with context.wrap_socket(socket.socket()) as s: + self.assertIs(s.version(), None) s.connect((HOST, server.port)) - # get the data - cb_data = s.get_channel_binding("tls-unique") - if support.verbose: - sys.stdout.write(" got channel binding data: {0!r}\n" - .format(cb_data)) - - # check if it is sane - self.assertIsNotNone(cb_data) - self.assertEqual(len(cb_data), 12) # True for TLSv1 - - # and compare with the peers version - s.write(b"CB tls-unique\n") - peer_data_repr = s.read().strip() - self.assertEqual(peer_data_repr, - repr(cb_data).encode("us-ascii")) - s.close() - - # now, again - s = test_wrap_socket(socket.socket(), - server_side=False, - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) + self.assertEqual(s.version(), 'TLSv1') + self.assertIs(s.version(), None) + + @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL") + def test_default_ecdh_curve(self): + # Issue #21015: elliptic curve-based Diffie Hellman key exchange + # should be enabled by default on SSL contexts. + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.load_cert_chain(CERTFILE) + # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled + # explicitly using the 'ECCdraft' cipher alias. Otherwise, + # our default cipher list should prefer ECDH-based ciphers + # automatically. + if ssl.OPENSSL_VERSION_INFO < (1, 0, 0): + context.set_ciphers("ECCdraft:ECDH") + with ThreadedEchoServer(context=context) as server: + with context.wrap_socket(socket.socket()) as s: s.connect((HOST, server.port)) - new_cb_data = s.get_channel_binding("tls-unique") - if support.verbose: - sys.stdout.write(" got another channel binding data: {0!r}\n" - .format(new_cb_data)) - # is it really unique - self.assertNotEqual(cb_data, new_cb_data) - self.assertIsNotNone(cb_data) - self.assertEqual(len(cb_data), 12) # True for TLSv1 - s.write(b"CB tls-unique\n") - peer_data_repr = s.read().strip() - self.assertEqual(peer_data_repr, - repr(new_cb_data).encode("us-ascii")) - s.close() + self.assertIn("ECDH", s.cipher()[0]) - def test_compression(self): - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) - if support.verbose: - sys.stdout.write(" got compression: {!r}\n".format(stats['compression'])) - self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' }) - - @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'), - "ssl.OP_NO_COMPRESSION needed for this test") - def test_compression_disabled(self): - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - context.options |= ssl.OP_NO_COMPRESSION - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) - self.assertIs(stats['compression'], None) - - def test_dh_params(self): - # Check we can get a connection with ephemeral Diffie-Hellman - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - context.load_dh_params(DHFILE) - context.set_ciphers("kEDH") - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) - cipher = stats["cipher"][0] - parts = cipher.split("-") - if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: - self.fail("Non-DH cipher: " + cipher[0]) - - def test_selected_alpn_protocol(self): - # selected_alpn_protocol() is None unless ALPN is used. - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) - self.assertIs(stats['client_alpn_protocol'], None) + @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, + "'tls-unique' channel binding not available") + def test_tls_unique_channel_binding(self): + """Test tls-unique channel binding.""" + if support.verbose: + sys.stdout.write("\n") - @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required") - def test_selected_alpn_protocol_if_server_uses_alpn(self): - # selected_alpn_protocol() is None unless ALPN is used by the client. - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.load_verify_locations(CERTFILE) - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + with server: + s = test_wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + # get the data + cb_data = s.get_channel_binding("tls-unique") + if support.verbose: + sys.stdout.write(" got channel binding data: {0!r}\n" + .format(cb_data)) + + # check if it is sane + self.assertIsNotNone(cb_data) + self.assertEqual(len(cb_data), 12) # True for TLSv1 + + # and compare with the peers version + s.write(b"CB tls-unique\n") + peer_data_repr = s.read().strip() + self.assertEqual(peer_data_repr, + repr(cb_data).encode("us-ascii")) + s.close() + + # now, again + s = test_wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + new_cb_data = s.get_channel_binding("tls-unique") + if support.verbose: + sys.stdout.write(" got another channel binding data: {0!r}\n" + .format(new_cb_data)) + # is it really unique + self.assertNotEqual(cb_data, new_cb_data) + self.assertIsNotNone(cb_data) + self.assertEqual(len(cb_data), 12) # True for TLSv1 + s.write(b"CB tls-unique\n") + peer_data_repr = s.read().strip() + self.assertEqual(peer_data_repr, + repr(new_cb_data).encode("us-ascii")) + s.close() + + def test_compression(self): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + if support.verbose: + sys.stdout.write(" got compression: {!r}\n".format(stats['compression'])) + self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' }) + + @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'), + "ssl.OP_NO_COMPRESSION needed for this test") + def test_compression_disabled(self): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + context.options |= ssl.OP_NO_COMPRESSION + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + self.assertIs(stats['compression'], None) + + def test_dh_params(self): + # Check we can get a connection with ephemeral Diffie-Hellman + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + context.load_dh_params(DHFILE) + context.set_ciphers("kEDH") + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + cipher = stats["cipher"][0] + parts = cipher.split("-") + if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: + self.fail("Non-DH cipher: " + cipher[0]) + + def test_selected_alpn_protocol(self): + # selected_alpn_protocol() is None unless ALPN is used. + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + self.assertIs(stats['client_alpn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required") + def test_selected_alpn_protocol_if_server_uses_alpn(self): + # selected_alpn_protocol() is None unless ALPN is used by the client. + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.load_verify_locations(CERTFILE) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(CERTFILE) + server_context.set_alpn_protocols(['foo', 'bar']) + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True) + self.assertIs(stats['client_alpn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test") + def test_alpn_protocols(self): + server_protocols = ['foo', 'bar', 'milkshake'] + protocol_tests = [ + (['foo', 'bar'], 'foo'), + (['bar', 'foo'], 'foo'), + (['milkshake'], 'milkshake'), + (['http/3.0', 'http/4.0'], None) + ] + for client_protocols, expected in protocol_tests: + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) server_context.load_cert_chain(CERTFILE) - server_context.set_alpn_protocols(['foo', 'bar']) - stats = server_params_test(client_context, server_context, - chatty=True, connectionchatty=True) - self.assertIs(stats['client_alpn_protocol'], None) - - @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test") - def test_alpn_protocols(self): - server_protocols = ['foo', 'bar', 'milkshake'] - protocol_tests = [ - (['foo', 'bar'], 'foo'), - (['bar', 'foo'], 'foo'), - (['milkshake'], 'milkshake'), - (['http/3.0', 'http/4.0'], None) - ] - for client_protocols, expected in protocol_tests: - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - server_context.load_cert_chain(CERTFILE) - server_context.set_alpn_protocols(server_protocols) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - client_context.load_cert_chain(CERTFILE) - client_context.set_alpn_protocols(client_protocols) + server_context.set_alpn_protocols(server_protocols) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + client_context.load_cert_chain(CERTFILE) + client_context.set_alpn_protocols(client_protocols) - try: - stats = server_params_test(client_context, - server_context, - chatty=True, - connectionchatty=True) - except ssl.SSLError as e: - stats = e - - if (expected is None and IS_OPENSSL_1_1 - and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)): - # OpenSSL 1.1.0 to 1.1.0e raises handshake error - self.assertIsInstance(stats, ssl.SSLError) - else: - msg = "failed trying %s (s) and %s (c).\n" \ - "was expecting %s, but got %%s from the %%s" \ - % (str(server_protocols), str(client_protocols), - str(expected)) - client_result = stats['client_alpn_protocol'] - self.assertEqual(client_result, expected, - msg % (client_result, "client")) - server_result = stats['server_alpn_protocols'][-1] \ - if len(stats['server_alpn_protocols']) else 'nothing' - self.assertEqual(server_result, expected, - msg % (server_result, "server")) - - def test_selected_npn_protocol(self): - # selected_npn_protocol() is None unless NPN is used - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) - self.assertIs(stats['client_npn_protocol'], None) - - @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test") - def test_npn_protocols(self): - server_protocols = ['http/1.1', 'spdy/2'] - protocol_tests = [ - (['http/1.1', 'spdy/2'], 'http/1.1'), - (['spdy/2', 'http/1.1'], 'http/1.1'), - (['spdy/2', 'test'], 'spdy/2'), - (['abc', 'def'], 'abc') - ] - for client_protocols, expected in protocol_tests: - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(CERTFILE) - server_context.set_npn_protocols(server_protocols) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.load_cert_chain(CERTFILE) - client_context.set_npn_protocols(client_protocols) - stats = server_params_test(client_context, server_context, - chatty=True, connectionchatty=True) + try: + stats = server_params_test(client_context, + server_context, + chatty=True, + connectionchatty=True) + except ssl.SSLError as e: + stats = e + if (expected is None and IS_OPENSSL_1_1 + and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)): + # OpenSSL 1.1.0 to 1.1.0e raises handshake error + self.assertIsInstance(stats, ssl.SSLError) + else: msg = "failed trying %s (s) and %s (c).\n" \ - "was expecting %s, but got %%s from the %%s" \ - % (str(server_protocols), str(client_protocols), - str(expected)) - client_result = stats['client_npn_protocol'] - self.assertEqual(client_result, expected, msg % (client_result, "client")) - server_result = stats['server_npn_protocols'][-1] \ - if len(stats['server_npn_protocols']) else 'nothing' - self.assertEqual(server_result, expected, msg % (server_result, "server")) - - def sni_contexts(self): + "was expecting %s, but got %%s from the %%s" \ + % (str(server_protocols), str(client_protocols), + str(expected)) + client_result = stats['client_alpn_protocol'] + self.assertEqual(client_result, expected, + msg % (client_result, "client")) + server_result = stats['server_alpn_protocols'][-1] \ + if len(stats['server_alpn_protocols']) else 'nothing' + self.assertEqual(server_result, expected, + msg % (server_result, "server")) + + def test_selected_npn_protocol(self): + # selected_npn_protocol() is None unless NPN is used + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + self.assertIs(stats['client_npn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test") + def test_npn_protocols(self): + server_protocols = ['http/1.1', 'spdy/2'] + protocol_tests = [ + (['http/1.1', 'spdy/2'], 'http/1.1'), + (['spdy/2', 'http/1.1'], 'http/1.1'), + (['spdy/2', 'test'], 'spdy/2'), + (['abc', 'def'], 'abc') + ] + for client_protocols, expected in protocol_tests: server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - other_context.load_cert_chain(SIGNED_CERTFILE2) + server_context.load_cert_chain(CERTFILE) + server_context.set_npn_protocols(server_protocols) client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.verify_mode = ssl.CERT_REQUIRED - client_context.load_verify_locations(SIGNING_CA) - return server_context, other_context, client_context + client_context.load_cert_chain(CERTFILE) + client_context.set_npn_protocols(client_protocols) + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True) - def check_common_name(self, stats, name): - cert = stats['peercert'] - self.assertIn((('commonName', name),), cert['subject']) + msg = "failed trying %s (s) and %s (c).\n" \ + "was expecting %s, but got %%s from the %%s" \ + % (str(server_protocols), str(client_protocols), + str(expected)) + client_result = stats['client_npn_protocol'] + self.assertEqual(client_result, expected, msg % (client_result, "client")) + server_result = stats['server_npn_protocols'][-1] \ + if len(stats['server_npn_protocols']) else 'nothing' + self.assertEqual(server_result, expected, msg % (server_result, "server")) + + def sni_contexts(self): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + other_context.load_cert_chain(SIGNED_CERTFILE2) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_verify_locations(SIGNING_CA) + return server_context, other_context, client_context + + def check_common_name(self, stats, name): + cert = stats['peercert'] + self.assertIn((('commonName', name),), cert['subject']) + + @needs_sni + def test_sni_callback(self): + calls = [] + server_context, other_context, client_context = self.sni_contexts() + + def servername_cb(ssl_sock, server_name, initial_context): + calls.append((server_name, initial_context)) + if server_name is not None: + ssl_sock.context = other_context + server_context.set_servername_callback(servername_cb) + + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name='supermessage') + # The hostname was fetched properly, and the certificate was + # changed for the connection. + self.assertEqual(calls, [("supermessage", server_context)]) + # CERTFILE4 was selected + self.check_common_name(stats, 'fakehostname') + + calls = [] + # The callback is called with server_name=None + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name=None) + self.assertEqual(calls, [(None, server_context)]) + self.check_common_name(stats, 'localhost') + + # Check disabling the callback + calls = [] + server_context.set_servername_callback(None) + + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name='notfunny') + # Certificate didn't change + self.check_common_name(stats, 'localhost') + self.assertEqual(calls, []) - @needs_sni - def test_sni_callback(self): - calls = [] - server_context, other_context, client_context = self.sni_contexts() + @needs_sni + def test_sni_callback_alert(self): + # Returning a TLS alert is reflected to the connecting client + server_context, other_context, client_context = self.sni_contexts() - def servername_cb(ssl_sock, server_name, initial_context): - calls.append((server_name, initial_context)) - if server_name is not None: - ssl_sock.context = other_context - server_context.set_servername_callback(servername_cb) + def cb_returning_alert(ssl_sock, server_name, initial_context): + return ssl.ALERT_DESCRIPTION_ACCESS_DENIED + server_context.set_servername_callback(cb_returning_alert) + with self.assertRaises(ssl.SSLError) as cm: stats = server_params_test(client_context, server_context, - chatty=True, + chatty=False, sni_name='supermessage') - # The hostname was fetched properly, and the certificate was - # changed for the connection. - self.assertEqual(calls, [("supermessage", server_context)]) - # CERTFILE4 was selected - self.check_common_name(stats, 'fakehostname') - - calls = [] - # The callback is called with server_name=None + self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED') + + @needs_sni + def test_sni_callback_raising(self): + # Raising fails the connection with a TLS handshake failure alert. + server_context, other_context, client_context = self.sni_contexts() + + def cb_raising(ssl_sock, server_name, initial_context): + 1/0 + server_context.set_servername_callback(cb_raising) + + with self.assertRaises(ssl.SSLError) as cm, \ + support.captured_stderr() as stderr: stats = server_params_test(client_context, server_context, - chatty=True, - sni_name=None) - self.assertEqual(calls, [(None, server_context)]) - self.check_common_name(stats, 'localhost') + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE') + self.assertIn("ZeroDivisionError", stderr.getvalue()) + + @needs_sni + def test_sni_callback_wrong_return_type(self): + # Returning the wrong return type terminates the TLS connection + # with an internal error alert. + server_context, other_context, client_context = self.sni_contexts() - # Check disabling the callback - calls = [] - server_context.set_servername_callback(None) + def cb_wrong_return_type(ssl_sock, server_name, initial_context): + return "foo" + server_context.set_servername_callback(cb_wrong_return_type) + with self.assertRaises(ssl.SSLError) as cm, \ + support.captured_stderr() as stderr: stats = server_params_test(client_context, server_context, - chatty=True, - sni_name='notfunny') - # Certificate didn't change - self.check_common_name(stats, 'localhost') - self.assertEqual(calls, []) - - @needs_sni - def test_sni_callback_alert(self): - # Returning a TLS alert is reflected to the connecting client - server_context, other_context, client_context = self.sni_contexts() - - def cb_returning_alert(ssl_sock, server_name, initial_context): - return ssl.ALERT_DESCRIPTION_ACCESS_DENIED - server_context.set_servername_callback(cb_returning_alert) - - with self.assertRaises(ssl.SSLError) as cm: - stats = server_params_test(client_context, server_context, - chatty=False, - sni_name='supermessage') - self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED') - - @needs_sni - def test_sni_callback_raising(self): - # Raising fails the connection with a TLS handshake failure alert. - server_context, other_context, client_context = self.sni_contexts() - - def cb_raising(ssl_sock, server_name, initial_context): - 1/0 - server_context.set_servername_callback(cb_raising) - - with self.assertRaises(ssl.SSLError) as cm, \ - support.captured_stderr() as stderr: - stats = server_params_test(client_context, server_context, - chatty=False, - sni_name='supermessage') - self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE') - self.assertIn("ZeroDivisionError", stderr.getvalue()) - - @needs_sni - def test_sni_callback_wrong_return_type(self): - # Returning the wrong return type terminates the TLS connection - # with an internal error alert. - server_context, other_context, client_context = self.sni_contexts() - - def cb_wrong_return_type(ssl_sock, server_name, initial_context): - return "foo" - server_context.set_servername_callback(cb_wrong_return_type) - - with self.assertRaises(ssl.SSLError) as cm, \ - support.captured_stderr() as stderr: - stats = server_params_test(client_context, server_context, - chatty=False, - sni_name='supermessage') - self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR') - self.assertIn("TypeError", stderr.getvalue()) - - def test_shared_ciphers(self): - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.verify_mode = ssl.CERT_REQUIRED - client_context.load_verify_locations(SIGNING_CA) - if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2): - client_context.set_ciphers("AES128:AES256") - server_context.set_ciphers("AES256") - alg1 = "AES256" - alg2 = "AES-256" - else: - client_context.set_ciphers("AES:3DES") - server_context.set_ciphers("3DES") - alg1 = "3DES" - alg2 = "DES-CBC3" - - stats = server_params_test(client_context, server_context) - ciphers = stats['server_shared_ciphers'][0] - self.assertGreater(len(ciphers), 0) - for name, tls_version, bits in ciphers: - if not alg1 in name.split("-") and alg2 not in name: - self.fail(name) - - def test_read_write_after_close_raises_valuerror(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - server = ThreadedEchoServer(context=context, chatty=False) - - with server: - s = context.wrap_socket(socket.socket()) + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR') + self.assertIn("TypeError", stderr.getvalue()) + + def test_shared_ciphers(self): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_verify_locations(SIGNING_CA) + if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2): + client_context.set_ciphers("AES128:AES256") + server_context.set_ciphers("AES256") + alg1 = "AES256" + alg2 = "AES-256" + else: + client_context.set_ciphers("AES:3DES") + server_context.set_ciphers("3DES") + alg1 = "3DES" + alg2 = "DES-CBC3" + + stats = server_params_test(client_context, server_context) + ciphers = stats['server_shared_ciphers'][0] + self.assertGreater(len(ciphers), 0) + for name, tls_version, bits in ciphers: + if not alg1 in name.split("-") and alg2 not in name: + self.fail(name) + + def test_read_write_after_close_raises_valuerror(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + + with server: + s = context.wrap_socket(socket.socket()) + s.connect((HOST, server.port)) + s.close() + + self.assertRaises(ValueError, s.read, 1024) + self.assertRaises(ValueError, s.write, b'hello') + + def test_sendfile(self): + TEST_DATA = b"x" * 512 + with open(support.TESTFN, 'wb') as f: + f.write(TEST_DATA) + self.addCleanup(support.unlink, support.TESTFN) + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + with server: + with context.wrap_socket(socket.socket()) as s: s.connect((HOST, server.port)) - s.close() + with open(support.TESTFN, 'rb') as file: + s.sendfile(file) + self.assertEqual(s.recv(1024), TEST_DATA) + + def test_session(self): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_verify_locations(SIGNING_CA) + + # first connection without session + stats = server_params_test(client_context, server_context) + session = stats['session'] + self.assertTrue(session.id) + self.assertGreater(session.time, 0) + self.assertGreater(session.timeout, 0) + self.assertTrue(session.has_ticket) + if ssl.OPENSSL_VERSION_INFO > (1, 0, 1): + self.assertGreater(session.ticket_lifetime_hint, 0) + self.assertFalse(stats['session_reused']) + sess_stat = server_context.session_stats() + self.assertEqual(sess_stat['accept'], 1) + self.assertEqual(sess_stat['hits'], 0) + + # reuse session + stats = server_params_test(client_context, server_context, session=session) + sess_stat = server_context.session_stats() + self.assertEqual(sess_stat['accept'], 2) + self.assertEqual(sess_stat['hits'], 1) + self.assertTrue(stats['session_reused']) + session2 = stats['session'] + self.assertEqual(session2.id, session.id) + self.assertEqual(session2, session) + self.assertIsNot(session2, session) + self.assertGreaterEqual(session2.time, session.time) + self.assertGreaterEqual(session2.timeout, session.timeout) + + # another one without session + stats = server_params_test(client_context, server_context) + self.assertFalse(stats['session_reused']) + session3 = stats['session'] + self.assertNotEqual(session3.id, session.id) + self.assertNotEqual(session3, session) + sess_stat = server_context.session_stats() + self.assertEqual(sess_stat['accept'], 3) + self.assertEqual(sess_stat['hits'], 1) + + # reuse session again + stats = server_params_test(client_context, server_context, session=session) + self.assertTrue(stats['session_reused']) + session4 = stats['session'] + self.assertEqual(session4.id, session.id) + self.assertEqual(session4, session) + self.assertGreaterEqual(session4.time, session.time) + self.assertGreaterEqual(session4.timeout, session.timeout) + sess_stat = server_context.session_stats() + self.assertEqual(sess_stat['accept'], 4) + self.assertEqual(sess_stat['hits'], 2) + + def test_session_handling(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + + context2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context2.verify_mode = ssl.CERT_REQUIRED + context2.load_verify_locations(CERTFILE) + context2.load_cert_chain(CERTFILE) + + server = ThreadedEchoServer(context=context, chatty=False) + with server: + with context.wrap_socket(socket.socket()) as s: + # session is None before handshake + self.assertEqual(s.session, None) + self.assertEqual(s.session_reused, None) + s.connect((HOST, server.port)) + session = s.session + self.assertTrue(session) + with self.assertRaises(TypeError) as e: + s.session = object + self.assertEqual(str(e.exception), 'Value is not a SSLSession.') - self.assertRaises(ValueError, s.read, 1024) - self.assertRaises(ValueError, s.write, b'hello') - - def test_sendfile(self): - TEST_DATA = b"x" * 512 - with open(support.TESTFN, 'wb') as f: - f.write(TEST_DATA) - self.addCleanup(support.unlink, support.TESTFN) - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - server = ThreadedEchoServer(context=context, chatty=False) - with server: - with context.wrap_socket(socket.socket()) as s: - s.connect((HOST, server.port)) - with open(support.TESTFN, 'rb') as file: - s.sendfile(file) - self.assertEqual(s.recv(1024), TEST_DATA) + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + # cannot set session after handshake + with self.assertRaises(ValueError) as e: + s.session = session + self.assertEqual(str(e.exception), + 'Cannot set session after handshake.') - def test_session(self): - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.verify_mode = ssl.CERT_REQUIRED - client_context.load_verify_locations(SIGNING_CA) - - # first connection without session - stats = server_params_test(client_context, server_context) - session = stats['session'] - self.assertTrue(session.id) - self.assertGreater(session.time, 0) - self.assertGreater(session.timeout, 0) - self.assertTrue(session.has_ticket) - if ssl.OPENSSL_VERSION_INFO > (1, 0, 1): - self.assertGreater(session.ticket_lifetime_hint, 0) - self.assertFalse(stats['session_reused']) - sess_stat = server_context.session_stats() - self.assertEqual(sess_stat['accept'], 1) - self.assertEqual(sess_stat['hits'], 0) - - # reuse session - stats = server_params_test(client_context, server_context, session=session) - sess_stat = server_context.session_stats() - self.assertEqual(sess_stat['accept'], 2) - self.assertEqual(sess_stat['hits'], 1) - self.assertTrue(stats['session_reused']) - session2 = stats['session'] - self.assertEqual(session2.id, session.id) - self.assertEqual(session2, session) - self.assertIsNot(session2, session) - self.assertGreaterEqual(session2.time, session.time) - self.assertGreaterEqual(session2.timeout, session.timeout) - - # another one without session - stats = server_params_test(client_context, server_context) - self.assertFalse(stats['session_reused']) - session3 = stats['session'] - self.assertNotEqual(session3.id, session.id) - self.assertNotEqual(session3, session) - sess_stat = server_context.session_stats() - self.assertEqual(sess_stat['accept'], 3) - self.assertEqual(sess_stat['hits'], 1) - - # reuse session again - stats = server_params_test(client_context, server_context, session=session) - self.assertTrue(stats['session_reused']) - session4 = stats['session'] - self.assertEqual(session4.id, session.id) - self.assertEqual(session4, session) - self.assertGreaterEqual(session4.time, session.time) - self.assertGreaterEqual(session4.timeout, session.timeout) - sess_stat = server_context.session_stats() - self.assertEqual(sess_stat['accept'], 4) - self.assertEqual(sess_stat['hits'], 2) - - def test_session_handling(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - - context2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context2.verify_mode = ssl.CERT_REQUIRED - context2.load_verify_locations(CERTFILE) - context2.load_cert_chain(CERTFILE) - - server = ThreadedEchoServer(context=context, chatty=False) - with server: - with context.wrap_socket(socket.socket()) as s: - # session is None before handshake - self.assertEqual(s.session, None) - self.assertEqual(s.session_reused, None) - s.connect((HOST, server.port)) - session = s.session - self.assertTrue(session) - with self.assertRaises(TypeError) as e: - s.session = object - self.assertEqual(str(e.exception), 'Value is not a SSLSession.') + with context.wrap_socket(socket.socket()) as s: + # can set session before handshake and before the + # connection was established + s.session = session + s.connect((HOST, server.port)) + self.assertEqual(s.session.id, session.id) + self.assertEqual(s.session, session) + self.assertEqual(s.session_reused, True) - with context.wrap_socket(socket.socket()) as s: - s.connect((HOST, server.port)) - # cannot set session after handshake - with self.assertRaises(ValueError) as e: - s.session = session - self.assertEqual(str(e.exception), - 'Cannot set session after handshake.') - - with context.wrap_socket(socket.socket()) as s: - # can set session before handshake and before the - # connection was established + with context2.wrap_socket(socket.socket()) as s: + # cannot re-use session with a different SSLContext + with self.assertRaises(ValueError) as e: s.session = session s.connect((HOST, server.port)) - self.assertEqual(s.session.id, session.id) - self.assertEqual(s.session, session) - self.assertEqual(s.session_reused, True) - - with context2.wrap_socket(socket.socket()) as s: - # cannot re-use session with a different SSLContext - with self.assertRaises(ValueError) as e: - s.session = session - s.connect((HOST, server.port)) - self.assertEqual(str(e.exception), - 'Session refers to a different SSLContext.') + self.assertEqual(str(e.exception), + 'Session refers to a different SSLContext.') def test_main(verbose=False): @@ -3610,22 +3603,17 @@ def test_main(verbose=False): tests = [ ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests, - SimpleBackgroundTests, + SimpleBackgroundTests, ThreadedTests, ] if support.is_resource_enabled('network'): tests.append(NetworkedTests) - if _have_threads: - thread_info = support.threading_setup() - if thread_info: - tests.append(ThreadedTests) - + thread_info = support.threading_setup() try: support.run_unittest(*tests) finally: - if _have_threads: - support.threading_cleanup(*thread_info) + support.threading_cleanup(*thread_info) if __name__ == "__main__": test_main() |