summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_ssl.py
diff options
context:
space:
mode:
authorAntoine Pitrou <pitrou@free.fr>2017-09-07 16:56:24 (GMT)
committerVictor Stinner <victor.stinner@gmail.com>2017-09-07 16:56:24 (GMT)
commita6a4dc816d68df04a7d592e0b6af8c7ecc4d4344 (patch)
tree1c31738009bee903417cea928e705a112aea2392 /Lib/test/test_ssl.py
parent1f06a680de465be0c24a78ea3b610053955daa99 (diff)
downloadcpython-a6a4dc816d68df04a7d592e0b6af8c7ecc4d4344.zip
cpython-a6a4dc816d68df04a7d592e0b6af8c7ecc4d4344.tar.gz
cpython-a6a4dc816d68df04a7d592e0b6af8c7ecc4d4344.tar.bz2
bpo-31370: Remove support for threads-less builds (#3385)
* Remove Setup.config * Always define WITH_THREAD for compatibility.
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r--Lib/test/test_ssl.py3222
1 files changed, 1605 insertions, 1617 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 16cad9d..89b4609 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -12,6 +12,7 @@ import os
import errno
import pprint
import urllib.request
+import threading
import traceback
import asyncore
import weakref
@@ -20,12 +21,6 @@ import functools
ssl = support.import_module("ssl")
-try:
- import threading
-except ImportError:
- _have_threads = False
-else:
- _have_threads = True
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = support.HOST
@@ -1468,7 +1463,6 @@ class MemoryBIOTests(unittest.TestCase):
self.assertRaises(TypeError, bio.write, 1)
-@unittest.skipUnless(_have_threads, "Needs threading module")
class SimpleBackgroundTests(unittest.TestCase):
"""Tests that connect to a simple server running in the background"""
@@ -1828,1744 +1822,1743 @@ def _test_get_server_certificate_fail(test, host, port):
test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
-if _have_threads:
- from test.ssl_servers import make_https_server
+from test.ssl_servers import make_https_server
- class ThreadedEchoServer(threading.Thread):
+class ThreadedEchoServer(threading.Thread):
- class ConnectionHandler(threading.Thread):
+ class ConnectionHandler(threading.Thread):
- """A mildly complicated class, because we want it to work both
- with and without the SSL wrapper around the socket connection, so
- that we can test the STARTTLS functionality."""
+ """A mildly complicated class, because we want it to work both
+ with and without the SSL wrapper around the socket connection, so
+ that we can test the STARTTLS functionality."""
- def __init__(self, server, connsock, addr):
- self.server = server
+ def __init__(self, server, connsock, addr):
+ self.server = server
+ self.running = False
+ self.sock = connsock
+ self.addr = addr
+ self.sock.setblocking(1)
+ self.sslconn = None
+ threading.Thread.__init__(self)
+ self.daemon = True
+
+ def wrap_conn(self):
+ try:
+ self.sslconn = self.server.context.wrap_socket(
+ self.sock, server_side=True)
+ self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
+ self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
+ except (ssl.SSLError, ConnectionResetError, OSError) as e:
+ # We treat ConnectionResetError as though it were an
+ # SSLError - OpenSSL on Ubuntu abruptly closes the
+ # connection when asked to use an unsupported protocol.
+ #
+ # OSError may occur with wrong protocols, e.g. both
+ # sides use PROTOCOL_TLS_SERVER.
+ #
+ # XXX Various errors can have happened here, for example
+ # a mismatching protocol version, an invalid certificate,
+ # or a low-level bug. This should be made more discriminating.
+ #
+ # bpo-31323: Store the exception as string to prevent
+ # a reference leak: server -> conn_errors -> exception
+ # -> traceback -> self (ConnectionHandler) -> server
+ self.server.conn_errors.append(str(e))
+ if self.server.chatty:
+ handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n")
self.running = False
- self.sock = connsock
- self.addr = addr
- self.sock.setblocking(1)
- self.sslconn = None
- threading.Thread.__init__(self)
- self.daemon = True
-
- def wrap_conn(self):
- try:
- self.sslconn = self.server.context.wrap_socket(
- self.sock, server_side=True)
- self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
- self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
- except (ssl.SSLError, ConnectionResetError, OSError) as e:
- # We treat ConnectionResetError as though it were an
- # SSLError - OpenSSL on Ubuntu abruptly closes the
- # connection when asked to use an unsupported protocol.
- #
- # OSError may occur with wrong protocols, e.g. both
- # sides use PROTOCOL_TLS_SERVER.
- #
- # XXX Various errors can have happened here, for example
- # a mismatching protocol version, an invalid certificate,
- # or a low-level bug. This should be made more discriminating.
- #
- # bpo-31323: Store the exception as string to prevent
- # a reference leak: server -> conn_errors -> exception
- # -> traceback -> self (ConnectionHandler) -> server
- self.server.conn_errors.append(str(e))
- if self.server.chatty:
- handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n")
- self.running = False
- self.server.stop()
- self.close()
- return False
- else:
- self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
- if self.server.context.verify_mode == ssl.CERT_REQUIRED:
- cert = self.sslconn.getpeercert()
- if support.verbose and self.server.chatty:
- sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
- cert_binary = self.sslconn.getpeercert(True)
- if support.verbose and self.server.chatty:
- sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
- cipher = self.sslconn.cipher()
+ self.server.stop()
+ self.close()
+ return False
+ else:
+ self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
+ if self.server.context.verify_mode == ssl.CERT_REQUIRED:
+ cert = self.sslconn.getpeercert()
if support.verbose and self.server.chatty:
- sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
- sys.stdout.write(" server: selected protocol is now "
- + str(self.sslconn.selected_npn_protocol()) + "\n")
- return True
-
- def read(self):
- if self.sslconn:
- return self.sslconn.read()
- else:
- return self.sock.recv(1024)
+ sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
+ cert_binary = self.sslconn.getpeercert(True)
+ if support.verbose and self.server.chatty:
+ sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
+ cipher = self.sslconn.cipher()
+ if support.verbose and self.server.chatty:
+ sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
+ sys.stdout.write(" server: selected protocol is now "
+ + str(self.sslconn.selected_npn_protocol()) + "\n")
+ return True
+
+ def read(self):
+ if self.sslconn:
+ return self.sslconn.read()
+ else:
+ return self.sock.recv(1024)
- def write(self, bytes):
- if self.sslconn:
- return self.sslconn.write(bytes)
- else:
- return self.sock.send(bytes)
+ def write(self, bytes):
+ if self.sslconn:
+ return self.sslconn.write(bytes)
+ else:
+ return self.sock.send(bytes)
- def close(self):
- if self.sslconn:
- self.sslconn.close()
- else:
- self.sock.close()
+ def close(self):
+ if self.sslconn:
+ self.sslconn.close()
+ else:
+ self.sock.close()
- def run(self):
- self.running = True
- if not self.server.starttls_server:
- if not self.wrap_conn():
- return
- while self.running:
- try:
- msg = self.read()
- stripped = msg.strip()
- if not stripped:
- # eof, so quit this handler
- self.running = False
- try:
- self.sock = self.sslconn.unwrap()
- except OSError:
- # Many tests shut the TCP connection down
- # without an SSL shutdown. This causes
- # unwrap() to raise OSError with errno=0!
- pass
- else:
- self.sslconn = None
- self.close()
- elif stripped == b'over':
- if support.verbose and self.server.connectionchatty:
- sys.stdout.write(" server: client closed connection\n")
- self.close()
- return
- elif (self.server.starttls_server and
- stripped == b'STARTTLS'):
- if support.verbose and self.server.connectionchatty:
- sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
- self.write(b"OK\n")
- if not self.wrap_conn():
- return
- elif (self.server.starttls_server and self.sslconn
- and stripped == b'ENDTLS'):
- if support.verbose and self.server.connectionchatty:
- sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
- self.write(b"OK\n")
+ def run(self):
+ self.running = True
+ if not self.server.starttls_server:
+ if not self.wrap_conn():
+ return
+ while self.running:
+ try:
+ msg = self.read()
+ stripped = msg.strip()
+ if not stripped:
+ # eof, so quit this handler
+ self.running = False
+ try:
self.sock = self.sslconn.unwrap()
- self.sslconn = None
- if support.verbose and self.server.connectionchatty:
- sys.stdout.write(" server: connection is now unencrypted...\n")
- elif stripped == b'CB tls-unique':
- if support.verbose and self.server.connectionchatty:
- sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
- data = self.sslconn.get_channel_binding("tls-unique")
- self.write(repr(data).encode("us-ascii") + b"\n")
+ except OSError:
+ # Many tests shut the TCP connection down
+ # without an SSL shutdown. This causes
+ # unwrap() to raise OSError with errno=0!
+ pass
else:
- if (support.verbose and
- self.server.connectionchatty):
- ctype = (self.sslconn and "encrypted") or "unencrypted"
- sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
- % (msg, ctype, msg.lower(), ctype))
- self.write(msg.lower())
- except OSError:
- if self.server.chatty:
- handle_error("Test server failure:\n")
+ self.sslconn = None
self.close()
- self.running = False
- # normally, we'd just stop here, but for the test
- # harness, we want to stop the server
- self.server.stop()
-
- def __init__(self, certificate=None, ssl_version=None,
- certreqs=None, cacerts=None,
- chatty=True, connectionchatty=False, starttls_server=False,
- npn_protocols=None, alpn_protocols=None,
- ciphers=None, context=None):
- if context:
- self.context = context
- else:
- self.context = ssl.SSLContext(ssl_version
- if ssl_version is not None
- else ssl.PROTOCOL_TLSv1)
- self.context.verify_mode = (certreqs if certreqs is not None
- else ssl.CERT_NONE)
- if cacerts:
- self.context.load_verify_locations(cacerts)
- if certificate:
- self.context.load_cert_chain(certificate)
- if npn_protocols:
- self.context.set_npn_protocols(npn_protocols)
- if alpn_protocols:
- self.context.set_alpn_protocols(alpn_protocols)
- if ciphers:
- self.context.set_ciphers(ciphers)
- self.chatty = chatty
- self.connectionchatty = connectionchatty
- self.starttls_server = starttls_server
- self.sock = socket.socket()
- self.port = support.bind_port(self.sock)
- self.flag = None
- self.active = False
- self.selected_npn_protocols = []
- self.selected_alpn_protocols = []
- self.shared_ciphers = []
- self.conn_errors = []
- threading.Thread.__init__(self)
- self.daemon = True
-
- def __enter__(self):
- self.start(threading.Event())
- self.flag.wait()
- return self
-
- def __exit__(self, *args):
- self.stop()
- self.join()
-
- def start(self, flag=None):
- self.flag = flag
- threading.Thread.start(self)
+ elif stripped == b'over':
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: client closed connection\n")
+ self.close()
+ return
+ elif (self.server.starttls_server and
+ stripped == b'STARTTLS'):
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
+ self.write(b"OK\n")
+ if not self.wrap_conn():
+ return
+ elif (self.server.starttls_server and self.sslconn
+ and stripped == b'ENDTLS'):
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
+ self.write(b"OK\n")
+ self.sock = self.sslconn.unwrap()
+ self.sslconn = None
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: connection is now unencrypted...\n")
+ elif stripped == b'CB tls-unique':
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
+ data = self.sslconn.get_channel_binding("tls-unique")
+ self.write(repr(data).encode("us-ascii") + b"\n")
+ else:
+ if (support.verbose and
+ self.server.connectionchatty):
+ ctype = (self.sslconn and "encrypted") or "unencrypted"
+ sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
+ % (msg, ctype, msg.lower(), ctype))
+ self.write(msg.lower())
+ except OSError:
+ if self.server.chatty:
+ handle_error("Test server failure:\n")
+ self.close()
+ self.running = False
+ # normally, we'd just stop here, but for the test
+ # harness, we want to stop the server
+ self.server.stop()
- def run(self):
- self.sock.settimeout(0.05)
- self.sock.listen()
- self.active = True
- if self.flag:
- # signal an event
- self.flag.set()
- while self.active:
- try:
- newconn, connaddr = self.sock.accept()
- if support.verbose and self.chatty:
- sys.stdout.write(' server: new connection from '
- + repr(connaddr) + '\n')
- handler = self.ConnectionHandler(self, newconn, connaddr)
- handler.start()
- handler.join()
- except socket.timeout:
- pass
- except KeyboardInterrupt:
- self.stop()
- self.sock.close()
+ def __init__(self, certificate=None, ssl_version=None,
+ certreqs=None, cacerts=None,
+ chatty=True, connectionchatty=False, starttls_server=False,
+ npn_protocols=None, alpn_protocols=None,
+ ciphers=None, context=None):
+ if context:
+ self.context = context
+ else:
+ self.context = ssl.SSLContext(ssl_version
+ if ssl_version is not None
+ else ssl.PROTOCOL_TLSv1)
+ self.context.verify_mode = (certreqs if certreqs is not None
+ else ssl.CERT_NONE)
+ if cacerts:
+ self.context.load_verify_locations(cacerts)
+ if certificate:
+ self.context.load_cert_chain(certificate)
+ if npn_protocols:
+ self.context.set_npn_protocols(npn_protocols)
+ if alpn_protocols:
+ self.context.set_alpn_protocols(alpn_protocols)
+ if ciphers:
+ self.context.set_ciphers(ciphers)
+ self.chatty = chatty
+ self.connectionchatty = connectionchatty
+ self.starttls_server = starttls_server
+ self.sock = socket.socket()
+ self.port = support.bind_port(self.sock)
+ self.flag = None
+ self.active = False
+ self.selected_npn_protocols = []
+ self.selected_alpn_protocols = []
+ self.shared_ciphers = []
+ self.conn_errors = []
+ threading.Thread.__init__(self)
+ self.daemon = True
+
+ def __enter__(self):
+ self.start(threading.Event())
+ self.flag.wait()
+ return self
+
+ def __exit__(self, *args):
+ self.stop()
+ self.join()
+
+ def start(self, flag=None):
+ self.flag = flag
+ threading.Thread.start(self)
+
+ def run(self):
+ self.sock.settimeout(0.05)
+ self.sock.listen()
+ self.active = True
+ if self.flag:
+ # signal an event
+ self.flag.set()
+ while self.active:
+ try:
+ newconn, connaddr = self.sock.accept()
+ if support.verbose and self.chatty:
+ sys.stdout.write(' server: new connection from '
+ + repr(connaddr) + '\n')
+ handler = self.ConnectionHandler(self, newconn, connaddr)
+ handler.start()
+ handler.join()
+ except socket.timeout:
+ pass
+ except KeyboardInterrupt:
+ self.stop()
+ self.sock.close()
- def stop(self):
- self.active = False
+ def stop(self):
+ self.active = False
- class AsyncoreEchoServer(threading.Thread):
+class AsyncoreEchoServer(threading.Thread):
- # this one's based on asyncore.dispatcher
+ # this one's based on asyncore.dispatcher
- class EchoServer (asyncore.dispatcher):
+ class EchoServer (asyncore.dispatcher):
- class ConnectionHandler(asyncore.dispatcher_with_send):
+ class ConnectionHandler(asyncore.dispatcher_with_send):
- def __init__(self, conn, certfile):
- self.socket = test_wrap_socket(conn, server_side=True,
- certfile=certfile,
- do_handshake_on_connect=False)
- asyncore.dispatcher_with_send.__init__(self, self.socket)
- self._ssl_accepting = True
- self._do_ssl_handshake()
+ def __init__(self, conn, certfile):
+ self.socket = test_wrap_socket(conn, server_side=True,
+ certfile=certfile,
+ do_handshake_on_connect=False)
+ asyncore.dispatcher_with_send.__init__(self, self.socket)
+ self._ssl_accepting = True
+ self._do_ssl_handshake()
- def readable(self):
- if isinstance(self.socket, ssl.SSLSocket):
- while self.socket.pending() > 0:
- self.handle_read_event()
- return True
+ def readable(self):
+ if isinstance(self.socket, ssl.SSLSocket):
+ while self.socket.pending() > 0:
+ self.handle_read_event()
+ return True
- def _do_ssl_handshake(self):
- try:
- self.socket.do_handshake()
- except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
- return
- except ssl.SSLEOFError:
+ def _do_ssl_handshake(self):
+ try:
+ self.socket.do_handshake()
+ except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
+ return
+ except ssl.SSLEOFError:
+ return self.handle_close()
+ except ssl.SSLError:
+ raise
+ except OSError as err:
+ if err.args[0] == errno.ECONNABORTED:
return self.handle_close()
- except ssl.SSLError:
- raise
- except OSError as err:
- if err.args[0] == errno.ECONNABORTED:
- return self.handle_close()
- else:
- self._ssl_accepting = False
-
- def handle_read(self):
- if self._ssl_accepting:
- self._do_ssl_handshake()
- else:
- data = self.recv(1024)
- if support.verbose:
- sys.stdout.write(" server: read %s from client\n" % repr(data))
- if not data:
- self.close()
- else:
- self.send(data.lower())
+ else:
+ self._ssl_accepting = False
- def handle_close(self):
- self.close()
+ def handle_read(self):
+ if self._ssl_accepting:
+ self._do_ssl_handshake()
+ else:
+ data = self.recv(1024)
if support.verbose:
- sys.stdout.write(" server: closed connection %s\n" % self.socket)
-
- def handle_error(self):
- raise
-
- def __init__(self, certfile):
- self.certfile = certfile
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.port = support.bind_port(sock, '')
- asyncore.dispatcher.__init__(self, sock)
- self.listen(5)
+ sys.stdout.write(" server: read %s from client\n" % repr(data))
+ if not data:
+ self.close()
+ else:
+ self.send(data.lower())
- def handle_accepted(self, sock_obj, addr):
+ def handle_close(self):
+ self.close()
if support.verbose:
- sys.stdout.write(" server: new connection from %s:%s\n" %addr)
- self.ConnectionHandler(sock_obj, self.certfile)
+ sys.stdout.write(" server: closed connection %s\n" % self.socket)
def handle_error(self):
raise
def __init__(self, certfile):
- self.flag = None
- self.active = False
- self.server = self.EchoServer(certfile)
- self.port = self.server.port
- threading.Thread.__init__(self)
- self.daemon = True
+ self.certfile = certfile
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.port = support.bind_port(sock, '')
+ asyncore.dispatcher.__init__(self, sock)
+ self.listen(5)
- def __str__(self):
- return "<%s %s>" % (self.__class__.__name__, self.server)
+ def handle_accepted(self, sock_obj, addr):
+ if support.verbose:
+ sys.stdout.write(" server: new connection from %s:%s\n" %addr)
+ self.ConnectionHandler(sock_obj, self.certfile)
- def __enter__(self):
- self.start(threading.Event())
- self.flag.wait()
- return self
+ def handle_error(self):
+ raise
- def __exit__(self, *args):
- if support.verbose:
- sys.stdout.write(" cleanup: stopping server.\n")
- self.stop()
- if support.verbose:
- sys.stdout.write(" cleanup: joining server thread.\n")
- self.join()
- if support.verbose:
- sys.stdout.write(" cleanup: successfully joined.\n")
- # make sure that ConnectionHandler is removed from socket_map
- asyncore.close_all(ignore_all=True)
+ def __init__(self, certfile):
+ self.flag = None
+ self.active = False
+ self.server = self.EchoServer(certfile)
+ self.port = self.server.port
+ threading.Thread.__init__(self)
+ self.daemon = True
- def start (self, flag=None):
- self.flag = flag
- threading.Thread.start(self)
+ def __str__(self):
+ return "<%s %s>" % (self.__class__.__name__, self.server)
- def run(self):
- self.active = True
- if self.flag:
- self.flag.set()
- while self.active:
- try:
- asyncore.loop(1)
- except:
- pass
+ def __enter__(self):
+ self.start(threading.Event())
+ self.flag.wait()
+ return self
- def stop(self):
- self.active = False
- self.server.close()
+ def __exit__(self, *args):
+ if support.verbose:
+ sys.stdout.write(" cleanup: stopping server.\n")
+ self.stop()
+ if support.verbose:
+ sys.stdout.write(" cleanup: joining server thread.\n")
+ self.join()
+ if support.verbose:
+ sys.stdout.write(" cleanup: successfully joined.\n")
+ # make sure that ConnectionHandler is removed from socket_map
+ asyncore.close_all(ignore_all=True)
+
+ def start (self, flag=None):
+ self.flag = flag
+ threading.Thread.start(self)
+
+ def run(self):
+ self.active = True
+ if self.flag:
+ self.flag.set()
+ while self.active:
+ try:
+ asyncore.loop(1)
+ except:
+ pass
- def server_params_test(client_context, server_context, indata=b"FOO\n",
- chatty=True, connectionchatty=False, sni_name=None,
- session=None):
- """
- Launch a server, connect a client to it and try various reads
- and writes.
- """
- stats = {}
- server = ThreadedEchoServer(context=server_context,
- chatty=chatty,
- connectionchatty=False)
- with server:
- with client_context.wrap_socket(socket.socket(),
- server_hostname=sni_name, session=session) as s:
- s.connect((HOST, server.port))
- for arg in [indata, bytearray(indata), memoryview(indata)]:
- if connectionchatty:
- if support.verbose:
- sys.stdout.write(
- " client: sending %r...\n" % indata)
- s.write(arg)
- outdata = s.read()
- if connectionchatty:
- if support.verbose:
- sys.stdout.write(" client: read %r\n" % outdata)
- if outdata != indata.lower():
- raise AssertionError(
- "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
- % (outdata[:20], len(outdata),
- indata[:20].lower(), len(indata)))
- s.write(b"over\n")
+ def stop(self):
+ self.active = False
+ self.server.close()
+
+def server_params_test(client_context, server_context, indata=b"FOO\n",
+ chatty=True, connectionchatty=False, sni_name=None,
+ session=None):
+ """
+ Launch a server, connect a client to it and try various reads
+ and writes.
+ """
+ stats = {}
+ server = ThreadedEchoServer(context=server_context,
+ chatty=chatty,
+ connectionchatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=sni_name, session=session) as s:
+ s.connect((HOST, server.port))
+ for arg in [indata, bytearray(indata), memoryview(indata)]:
if connectionchatty:
if support.verbose:
- sys.stdout.write(" client: closing connection.\n")
- stats.update({
- 'compression': s.compression(),
- 'cipher': s.cipher(),
- 'peercert': s.getpeercert(),
- 'client_alpn_protocol': s.selected_alpn_protocol(),
- 'client_npn_protocol': s.selected_npn_protocol(),
- 'version': s.version(),
- 'session_reused': s.session_reused,
- 'session': s.session,
- })
- s.close()
- stats['server_alpn_protocols'] = server.selected_alpn_protocols
- stats['server_npn_protocols'] = server.selected_npn_protocols
- stats['server_shared_ciphers'] = server.shared_ciphers
- return stats
+ sys.stdout.write(
+ " client: sending %r...\n" % indata)
+ s.write(arg)
+ outdata = s.read()
+ if connectionchatty:
+ if support.verbose:
+ sys.stdout.write(" client: read %r\n" % outdata)
+ if outdata != indata.lower():
+ raise AssertionError(
+ "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
+ % (outdata[:20], len(outdata),
+ indata[:20].lower(), len(indata)))
+ s.write(b"over\n")
+ if connectionchatty:
+ if support.verbose:
+ sys.stdout.write(" client: closing connection.\n")
+ stats.update({
+ 'compression': s.compression(),
+ 'cipher': s.cipher(),
+ 'peercert': s.getpeercert(),
+ 'client_alpn_protocol': s.selected_alpn_protocol(),
+ 'client_npn_protocol': s.selected_npn_protocol(),
+ 'version': s.version(),
+ 'session_reused': s.session_reused,
+ 'session': s.session,
+ })
+ s.close()
+ stats['server_alpn_protocols'] = server.selected_alpn_protocols
+ stats['server_npn_protocols'] = server.selected_npn_protocols
+ stats['server_shared_ciphers'] = server.shared_ciphers
+ return stats
+
+def try_protocol_combo(server_protocol, client_protocol, expect_success,
+ certsreqs=None, server_options=0, client_options=0):
+ """
+ Try to SSL-connect using *client_protocol* to *server_protocol*.
+ If *expect_success* is true, assert that the connection succeeds,
+ if it's false, assert that the connection fails.
+ Also, if *expect_success* is a string, assert that it is the protocol
+ version actually used by the connection.
+ """
+ if certsreqs is None:
+ certsreqs = ssl.CERT_NONE
+ certtype = {
+ ssl.CERT_NONE: "CERT_NONE",
+ ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
+ ssl.CERT_REQUIRED: "CERT_REQUIRED",
+ }[certsreqs]
+ if support.verbose:
+ formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n"
+ sys.stdout.write(formatstr %
+ (ssl.get_protocol_name(client_protocol),
+ ssl.get_protocol_name(server_protocol),
+ certtype))
+ client_context = ssl.SSLContext(client_protocol)
+ client_context.options |= client_options
+ server_context = ssl.SSLContext(server_protocol)
+ server_context.options |= server_options
+
+ # NOTE: we must enable "ALL" ciphers on the client, otherwise an
+ # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
+ # starting from OpenSSL 1.0.0 (see issue #8322).
+ if client_context.protocol == ssl.PROTOCOL_SSLv23:
+ client_context.set_ciphers("ALL")
+
+ for ctx in (client_context, server_context):
+ ctx.verify_mode = certsreqs
+ ctx.load_cert_chain(CERTFILE)
+ ctx.load_verify_locations(CERTFILE)
+ try:
+ stats = server_params_test(client_context, server_context,
+ chatty=False, connectionchatty=False)
+ # Protocol mismatch can result in either an SSLError, or a
+ # "Connection reset by peer" error.
+ except ssl.SSLError:
+ if expect_success:
+ raise
+ except OSError as e:
+ if expect_success or e.errno != errno.ECONNRESET:
+ raise
+ else:
+ if not expect_success:
+ raise AssertionError(
+ "Client protocol %s succeeded with server protocol %s!"
+ % (ssl.get_protocol_name(client_protocol),
+ ssl.get_protocol_name(server_protocol)))
+ elif (expect_success is not True
+ and expect_success != stats['version']):
+ raise AssertionError("version mismatch: expected %r, got %r"
+ % (expect_success, stats['version']))
- def try_protocol_combo(server_protocol, client_protocol, expect_success,
- certsreqs=None, server_options=0, client_options=0):
- """
- Try to SSL-connect using *client_protocol* to *server_protocol*.
- If *expect_success* is true, assert that the connection succeeds,
- if it's false, assert that the connection fails.
- Also, if *expect_success* is a string, assert that it is the protocol
- version actually used by the connection.
- """
- if certsreqs is None:
- certsreqs = ssl.CERT_NONE
- certtype = {
- ssl.CERT_NONE: "CERT_NONE",
- ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
- ssl.CERT_REQUIRED: "CERT_REQUIRED",
- }[certsreqs]
- if support.verbose:
- formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n"
- sys.stdout.write(formatstr %
- (ssl.get_protocol_name(client_protocol),
- ssl.get_protocol_name(server_protocol),
- certtype))
- client_context = ssl.SSLContext(client_protocol)
- client_context.options |= client_options
- server_context = ssl.SSLContext(server_protocol)
- server_context.options |= server_options
-
- # NOTE: we must enable "ALL" ciphers on the client, otherwise an
- # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
- # starting from OpenSSL 1.0.0 (see issue #8322).
- if client_context.protocol == ssl.PROTOCOL_SSLv23:
- client_context.set_ciphers("ALL")
-
- for ctx in (client_context, server_context):
- ctx.verify_mode = certsreqs
- ctx.load_cert_chain(CERTFILE)
- ctx.load_verify_locations(CERTFILE)
- try:
- stats = server_params_test(client_context, server_context,
- chatty=False, connectionchatty=False)
- # Protocol mismatch can result in either an SSLError, or a
- # "Connection reset by peer" error.
- except ssl.SSLError:
- if expect_success:
- raise
- except OSError as e:
- if expect_success or e.errno != errno.ECONNRESET:
- raise
- else:
- if not expect_success:
- raise AssertionError(
- "Client protocol %s succeeded with server protocol %s!"
- % (ssl.get_protocol_name(client_protocol),
- ssl.get_protocol_name(server_protocol)))
- elif (expect_success is not True
- and expect_success != stats['version']):
- raise AssertionError("version mismatch: expected %r, got %r"
- % (expect_success, stats['version']))
-
-
- class ThreadedTests(unittest.TestCase):
-
- @skip_if_broken_ubuntu_ssl
- def test_echo(self):
- """Basic test of an SSL client connecting to a server"""
- if support.verbose:
- sys.stdout.write("\n")
- for protocol in PROTOCOLS:
- if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
- continue
- with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]):
- context = ssl.SSLContext(protocol)
- context.load_cert_chain(CERTFILE)
- server_params_test(context, context,
- chatty=True, connectionchatty=True)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
- client_context.load_verify_locations(SIGNING_CA)
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
- # server_context.load_verify_locations(SIGNING_CA)
- server_context.load_cert_chain(SIGNED_CERTFILE2)
+class ThreadedTests(unittest.TestCase):
- with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
- server_params_test(client_context=client_context,
- server_context=server_context,
+ @skip_if_broken_ubuntu_ssl
+ def test_echo(self):
+ """Basic test of an SSL client connecting to a server"""
+ if support.verbose:
+ sys.stdout.write("\n")
+ for protocol in PROTOCOLS:
+ if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
+ continue
+ with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]):
+ context = ssl.SSLContext(protocol)
+ context.load_cert_chain(CERTFILE)
+ server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ client_context.load_verify_locations(SIGNING_CA)
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ # server_context.load_verify_locations(SIGNING_CA)
+ server_context.load_cert_chain(SIGNED_CERTFILE2)
+
+ with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
+ server_params_test(client_context=client_context,
+ server_context=server_context,
+ chatty=True, connectionchatty=True,
+ sni_name='fakehostname')
+
+ client_context.check_hostname = False
+ with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
+ with self.assertRaises(ssl.SSLError) as e:
+ server_params_test(client_context=server_context,
+ server_context=client_context,
chatty=True, connectionchatty=True,
sni_name='fakehostname')
+ self.assertIn('called a function you should not call',
+ str(e.exception))
- client_context.check_hostname = False
- with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
- with self.assertRaises(ssl.SSLError) as e:
- server_params_test(client_context=server_context,
- server_context=client_context,
- chatty=True, connectionchatty=True,
- sni_name='fakehostname')
- self.assertIn('called a function you should not call',
- str(e.exception))
-
- with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
- with self.assertRaises(ssl.SSLError) as e:
- server_params_test(client_context=server_context,
- server_context=server_context,
- chatty=True, connectionchatty=True)
- self.assertIn('called a function you should not call',
- str(e.exception))
+ with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
+ with self.assertRaises(ssl.SSLError) as e:
+ server_params_test(client_context=server_context,
+ server_context=server_context,
+ chatty=True, connectionchatty=True)
+ self.assertIn('called a function you should not call',
+ str(e.exception))
- with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT):
- with self.assertRaises(ssl.SSLError) as e:
- server_params_test(client_context=server_context,
- server_context=client_context,
- chatty=True, connectionchatty=True)
- self.assertIn('called a function you should not call',
- str(e.exception))
+ with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT):
+ with self.assertRaises(ssl.SSLError) as e:
+ server_params_test(client_context=server_context,
+ server_context=client_context,
+ chatty=True, connectionchatty=True)
+ self.assertIn('called a function you should not call',
+ str(e.exception))
- def test_getpeercert(self):
+ def test_getpeercert(self):
+ if support.verbose:
+ sys.stdout.write("\n")
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+ server = ThreadedEchoServer(context=context, chatty=False)
+ with server:
+ s = context.wrap_socket(socket.socket(),
+ do_handshake_on_connect=False)
+ s.connect((HOST, server.port))
+ # getpeercert() raise ValueError while the handshake isn't
+ # done.
+ with self.assertRaises(ValueError):
+ s.getpeercert()
+ s.do_handshake()
+ cert = s.getpeercert()
+ self.assertTrue(cert, "Can't get peer certificate.")
+ cipher = s.cipher()
if support.verbose:
- sys.stdout.write("\n")
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(CERTFILE)
- context.load_cert_chain(CERTFILE)
- server = ThreadedEchoServer(context=context, chatty=False)
- with server:
- s = context.wrap_socket(socket.socket(),
- do_handshake_on_connect=False)
+ sys.stdout.write(pprint.pformat(cert) + '\n')
+ sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
+ if 'subject' not in cert:
+ self.fail("No subject field in certificate: %s." %
+ pprint.pformat(cert))
+ if ((('organizationName', 'Python Software Foundation'),)
+ not in cert['subject']):
+ self.fail(
+ "Missing or invalid 'organizationName' field in certificate subject; "
+ "should be 'Python Software Foundation'.")
+ self.assertIn('notBefore', cert)
+ self.assertIn('notAfter', cert)
+ before = ssl.cert_time_to_seconds(cert['notBefore'])
+ after = ssl.cert_time_to_seconds(cert['notAfter'])
+ self.assertLess(before, after)
+ s.close()
+
+ @unittest.skipUnless(have_verify_flags(),
+ "verify_flags need OpenSSL > 0.9.8")
+ def test_crl_check(self):
+ if support.verbose:
+ sys.stdout.write("\n")
+
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(SIGNING_CA)
+ tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
+ self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf)
+
+ # VERIFY_DEFAULT should pass
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
- # getpeercert() raise ValueError while the handshake isn't
- # done.
- with self.assertRaises(ValueError):
- s.getpeercert()
- s.do_handshake()
cert = s.getpeercert()
self.assertTrue(cert, "Can't get peer certificate.")
- cipher = s.cipher()
- if support.verbose:
- sys.stdout.write(pprint.pformat(cert) + '\n')
- sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
- if 'subject' not in cert:
- self.fail("No subject field in certificate: %s." %
- pprint.pformat(cert))
- if ((('organizationName', 'Python Software Foundation'),)
- not in cert['subject']):
- self.fail(
- "Missing or invalid 'organizationName' field in certificate subject; "
- "should be 'Python Software Foundation'.")
- self.assertIn('notBefore', cert)
- self.assertIn('notAfter', cert)
- before = ssl.cert_time_to_seconds(cert['notBefore'])
- after = ssl.cert_time_to_seconds(cert['notAfter'])
- self.assertLess(before, after)
- s.close()
- @unittest.skipUnless(have_verify_flags(),
- "verify_flags need OpenSSL > 0.9.8")
- def test_crl_check(self):
- if support.verbose:
- sys.stdout.write("\n")
+ # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
+ context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(SIGNED_CERTFILE)
-
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(SIGNING_CA)
- tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
- self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf)
-
- # VERIFY_DEFAULT should pass
- server = ThreadedEchoServer(context=server_context, chatty=True)
- with server:
- with context.wrap_socket(socket.socket()) as s:
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
+ with self.assertRaisesRegex(ssl.SSLError,
+ "certificate verify failed"):
s.connect((HOST, server.port))
- cert = s.getpeercert()
- self.assertTrue(cert, "Can't get peer certificate.")
- # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
- context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
+ # now load a CRL file. The CRL file is signed by the CA.
+ context.load_verify_locations(CRLFILE)
- server = ThreadedEchoServer(context=server_context, chatty=True)
- with server:
- with context.wrap_socket(socket.socket()) as s:
- with self.assertRaisesRegex(ssl.SSLError,
- "certificate verify failed"):
- s.connect((HOST, server.port))
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
+ s.connect((HOST, server.port))
+ cert = s.getpeercert()
+ self.assertTrue(cert, "Can't get peer certificate.")
- # now load a CRL file. The CRL file is signed by the CA.
- context.load_verify_locations(CRLFILE)
+ def test_check_hostname(self):
+ if support.verbose:
+ sys.stdout.write("\n")
- server = ThreadedEchoServer(context=server_context, chatty=True)
- with server:
- with context.wrap_socket(socket.socket()) as s:
- s.connect((HOST, server.port))
- cert = s.getpeercert()
- self.assertTrue(cert, "Can't get peer certificate.")
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
- def test_check_hostname(self):
- if support.verbose:
- sys.stdout.write("\n")
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.check_hostname = True
+ context.load_verify_locations(SIGNING_CA)
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(SIGNED_CERTFILE)
-
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.verify_mode = ssl.CERT_REQUIRED
- context.check_hostname = True
- context.load_verify_locations(SIGNING_CA)
-
- # correct hostname should verify
- server = ThreadedEchoServer(context=server_context, chatty=True)
- with server:
- with context.wrap_socket(socket.socket(),
- server_hostname="localhost") as s:
+ # correct hostname should verify
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with context.wrap_socket(socket.socket(),
+ server_hostname="localhost") as s:
+ s.connect((HOST, server.port))
+ cert = s.getpeercert()
+ self.assertTrue(cert, "Can't get peer certificate.")
+
+ # incorrect hostname should raise an exception
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with context.wrap_socket(socket.socket(),
+ server_hostname="invalid") as s:
+ with self.assertRaisesRegex(ssl.CertificateError,
+ "hostname 'invalid' doesn't match 'localhost'"):
s.connect((HOST, server.port))
- cert = s.getpeercert()
- self.assertTrue(cert, "Can't get peer certificate.")
-
- # incorrect hostname should raise an exception
- server = ThreadedEchoServer(context=server_context, chatty=True)
- with server:
- with context.wrap_socket(socket.socket(),
- server_hostname="invalid") as s:
- with self.assertRaisesRegex(ssl.CertificateError,
- "hostname 'invalid' doesn't match 'localhost'"):
- s.connect((HOST, server.port))
-
- # missing server_hostname arg should cause an exception, too
- server = ThreadedEchoServer(context=server_context, chatty=True)
- with server:
- with socket.socket() as s:
- with self.assertRaisesRegex(ValueError,
- "check_hostname requires server_hostname"):
- context.wrap_socket(s)
-
- def test_wrong_cert(self):
- """Connecting when the server rejects the client's certificate
-
- Launch a server with CERT_REQUIRED, and check that trying to
- connect to it with a wrong client certificate fails.
- """
- certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
- "wrongcert.pem")
- server = ThreadedEchoServer(CERTFILE,
- certreqs=ssl.CERT_REQUIRED,
- cacerts=CERTFILE, chatty=False,
- connectionchatty=False)
- with server, \
- socket.socket() as sock, \
- test_wrap_socket(sock,
- certfile=certfile,
- ssl_version=ssl.PROTOCOL_TLSv1) as s:
+
+ # missing server_hostname arg should cause an exception, too
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with socket.socket() as s:
+ with self.assertRaisesRegex(ValueError,
+ "check_hostname requires server_hostname"):
+ context.wrap_socket(s)
+
+ def test_wrong_cert(self):
+ """Connecting when the server rejects the client's certificate
+
+ Launch a server with CERT_REQUIRED, and check that trying to
+ connect to it with a wrong client certificate fails.
+ """
+ certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
+ "wrongcert.pem")
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_REQUIRED,
+ cacerts=CERTFILE, chatty=False,
+ connectionchatty=False)
+ with server, \
+ socket.socket() as sock, \
+ test_wrap_socket(sock,
+ certfile=certfile,
+ ssl_version=ssl.PROTOCOL_TLSv1) as s:
+ try:
+ # Expect either an SSL error about the server rejecting
+ # the connection, or a low-level connection reset (which
+ # sometimes happens on Windows)
+ s.connect((HOST, server.port))
+ except ssl.SSLError as e:
+ if support.verbose:
+ sys.stdout.write("\nSSLError is %r\n" % e)
+ except OSError as e:
+ if e.errno != errno.ECONNRESET:
+ raise
+ if support.verbose:
+ sys.stdout.write("\nsocket.error is %r\n" % e)
+ else:
+ self.fail("Use of invalid cert should have failed!")
+
+ def test_rude_shutdown(self):
+ """A brutal shutdown of an SSL server should raise an OSError
+ in the client when attempting handshake.
+ """
+ listener_ready = threading.Event()
+ listener_gone = threading.Event()
+
+ s = socket.socket()
+ port = support.bind_port(s, HOST)
+
+ # `listener` runs in a thread. It sits in an accept() until
+ # the main thread connects. Then it rudely closes the socket,
+ # and sets Event `listener_gone` to let the main thread know
+ # the socket is gone.
+ def listener():
+ s.listen()
+ listener_ready.set()
+ newsock, addr = s.accept()
+ newsock.close()
+ s.close()
+ listener_gone.set()
+
+ def connector():
+ listener_ready.wait()
+ with socket.socket() as c:
+ c.connect((HOST, port))
+ listener_gone.wait()
try:
- # Expect either an SSL error about the server rejecting
- # the connection, or a low-level connection reset (which
- # sometimes happens on Windows)
- s.connect((HOST, server.port))
- except ssl.SSLError as e:
- if support.verbose:
- sys.stdout.write("\nSSLError is %r\n" % e)
- except OSError as e:
- if e.errno != errno.ECONNRESET:
- raise
- if support.verbose:
- sys.stdout.write("\nsocket.error is %r\n" % e)
+ ssl_sock = test_wrap_socket(c)
+ except OSError:
+ pass
else:
- self.fail("Use of invalid cert should have failed!")
+ self.fail('connecting to closed SSL socket should have failed')
- def test_rude_shutdown(self):
- """A brutal shutdown of an SSL server should raise an OSError
- in the client when attempting handshake.
- """
- listener_ready = threading.Event()
- listener_gone = threading.Event()
+ t = threading.Thread(target=listener)
+ t.start()
+ try:
+ connector()
+ finally:
+ t.join()
- s = socket.socket()
- port = support.bind_port(s, HOST)
-
- # `listener` runs in a thread. It sits in an accept() until
- # the main thread connects. Then it rudely closes the socket,
- # and sets Event `listener_gone` to let the main thread know
- # the socket is gone.
- def listener():
- s.listen()
- listener_ready.set()
- newsock, addr = s.accept()
- newsock.close()
- s.close()
- listener_gone.set()
-
- def connector():
- listener_ready.wait()
- with socket.socket() as c:
- c.connect((HOST, port))
- listener_gone.wait()
- try:
- ssl_sock = test_wrap_socket(c)
- except OSError:
- pass
- else:
- self.fail('connecting to closed SSL socket should have failed')
+ @skip_if_broken_ubuntu_ssl
+ @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'),
+ "OpenSSL is compiled without SSLv2 support")
+ def test_protocol_sslv2(self):
+ """Connecting to an SSLv2 server with various client options"""
+ if support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False)
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
+ # SSLv23 client with specific SSL options
+ if no_sslv2_implies_sslv3_hello():
+ # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_SSLv2)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_SSLv3)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_TLSv1)
- t = threading.Thread(target=listener)
- t.start()
+ @skip_if_broken_ubuntu_ssl
+ def test_protocol_sslv23(self):
+ """Connecting to an SSLv23 server with various client options"""
+ if support.verbose:
+ sys.stdout.write("\n")
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
try:
- connector()
- finally:
- t.join()
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True)
+ except OSError as x:
+ # this fails on some older versions of OpenSSL (0.9.7l, for instance)
+ if support.verbose:
+ sys.stdout.write(
+ " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
+ % str(x))
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1')
+
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
+
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
+
+ # Server with specific SSL options
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False,
+ server_options=ssl.OP_NO_SSLv3)
+ # Will choose TLSv1
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True,
+ server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False,
+ server_options=ssl.OP_NO_TLSv1)
+
+
+ @skip_if_broken_ubuntu_ssl
+ @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv3'),
+ "OpenSSL is compiled without SSLv3 support")
+ def test_protocol_sslv3(self):
+ """Connecting to an SSLv3 server with various client options"""
+ if support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_SSLv3)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
+ if no_sslv2_implies_sslv3_hello():
+ # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23,
+ False, client_options=ssl.OP_NO_SSLv2)
+
+ @skip_if_broken_ubuntu_ssl
+ def test_protocol_tlsv1(self):
+ """Connecting to a TLSv1 server with various client options"""
+ if support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_TLSv1)
+
+ @skip_if_broken_ubuntu_ssl
+ @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"),
+ "TLS version 1.1 not supported.")
+ def test_protocol_tlsv1_1(self):
+ """Connecting to a TLSv1.1 server with various client options.
+ Testing against older TLS versions."""
+ if support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_TLSv1_1)
+
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False)
- @skip_if_broken_ubuntu_ssl
- @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'),
- "OpenSSL is compiled without SSLv2 support")
- def test_protocol_sslv2(self):
- """Connecting to an SSLv2 server with various client options"""
- if support.verbose:
- sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False)
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
- # SSLv23 client with specific SSL options
- if no_sslv2_implies_sslv3_hello():
- # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_SSLv2)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_SSLv3)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_TLSv1)
- @skip_if_broken_ubuntu_ssl
- def test_protocol_sslv23(self):
- """Connecting to an SSLv23 server with various client options"""
+ @skip_if_broken_ubuntu_ssl
+ @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"),
+ "TLS version 1.2 not supported.")
+ def test_protocol_tlsv1_2(self):
+ """Connecting to a TLSv1.2 server with various client options.
+ Testing against older TLS versions."""
+ if support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
+ server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
+ client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_TLSv1_2)
+
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
+
+ def test_starttls(self):
+ """Switching from clear text to encrypted and back again."""
+ msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
+
+ server = ThreadedEchoServer(CERTFILE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ starttls_server=True,
+ chatty=True,
+ connectionchatty=True)
+ wrapped = False
+ with server:
+ s = socket.socket()
+ s.setblocking(1)
+ s.connect((HOST, server.port))
if support.verbose:
sys.stdout.write("\n")
- if hasattr(ssl, 'PROTOCOL_SSLv2'):
- try:
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True)
- except OSError as x:
- # this fails on some older versions of OpenSSL (0.9.7l, for instance)
+ for indata in msgs:
+ if support.verbose:
+ sys.stdout.write(
+ " client: sending %r...\n" % indata)
+ if wrapped:
+ conn.write(indata)
+ outdata = conn.read()
+ else:
+ s.send(indata)
+ outdata = s.recv(1024)
+ msg = outdata.strip().lower()
+ if indata == b"STARTTLS" and msg.startswith(b"ok"):
+ # STARTTLS ok, switch to secure mode
if support.verbose:
sys.stdout.write(
- " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
- % str(x))
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1')
-
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
-
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
-
- # Server with specific SSL options
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False,
- server_options=ssl.OP_NO_SSLv3)
- # Will choose TLSv1
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True,
- server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False,
- server_options=ssl.OP_NO_TLSv1)
-
-
- @skip_if_broken_ubuntu_ssl
- @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv3'),
- "OpenSSL is compiled without SSLv3 support")
- def test_protocol_sslv3(self):
- """Connecting to an SSLv3 server with various client options"""
- if support.verbose:
- sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
- if hasattr(ssl, 'PROTOCOL_SSLv2'):
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_SSLv3)
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
- if no_sslv2_implies_sslv3_hello():
- # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23,
- False, client_options=ssl.OP_NO_SSLv2)
-
- @skip_if_broken_ubuntu_ssl
- def test_protocol_tlsv1(self):
- """Connecting to a TLSv1 server with various client options"""
- if support.verbose:
- sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
- if hasattr(ssl, 'PROTOCOL_SSLv2'):
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_TLSv1)
-
- @skip_if_broken_ubuntu_ssl
- @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"),
- "TLS version 1.1 not supported.")
- def test_protocol_tlsv1_1(self):
- """Connecting to a TLSv1.1 server with various client options.
- Testing against older TLS versions."""
- if support.verbose:
- sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
- if hasattr(ssl, 'PROTOCOL_SSLv2'):
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_TLSv1_1)
-
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False)
-
-
- @skip_if_broken_ubuntu_ssl
- @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"),
- "TLS version 1.2 not supported.")
- def test_protocol_tlsv1_2(self):
- """Connecting to a TLSv1.2 server with various client options.
- Testing against older TLS versions."""
- if support.verbose:
- sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
- server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
- client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
- if hasattr(ssl, 'PROTOCOL_SSLv2'):
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_TLSv1_2)
-
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
-
- def test_starttls(self):
- """Switching from clear text to encrypted and back again."""
- msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
-
- server = ThreadedEchoServer(CERTFILE,
- ssl_version=ssl.PROTOCOL_TLSv1,
- starttls_server=True,
- chatty=True,
- connectionchatty=True)
- wrapped = False
- with server:
- s = socket.socket()
- s.setblocking(1)
- s.connect((HOST, server.port))
- if support.verbose:
- sys.stdout.write("\n")
- for indata in msgs:
+ " client: read %r from server, starting TLS...\n"
+ % msg)
+ conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
+ wrapped = True
+ elif indata == b"ENDTLS" and msg.startswith(b"ok"):
+ # ENDTLS ok, switch back to clear text
if support.verbose:
sys.stdout.write(
- " client: sending %r...\n" % indata)
- if wrapped:
- conn.write(indata)
- outdata = conn.read()
- else:
- s.send(indata)
- outdata = s.recv(1024)
- msg = outdata.strip().lower()
- if indata == b"STARTTLS" and msg.startswith(b"ok"):
- # STARTTLS ok, switch to secure mode
- if support.verbose:
- sys.stdout.write(
- " client: read %r from server, starting TLS...\n"
- % msg)
- conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
- wrapped = True
- elif indata == b"ENDTLS" and msg.startswith(b"ok"):
- # ENDTLS ok, switch back to clear text
- if support.verbose:
- sys.stdout.write(
- " client: read %r from server, ending TLS...\n"
- % msg)
- s = conn.unwrap()
- wrapped = False
- else:
- if support.verbose:
- sys.stdout.write(
- " client: read %r from server\n" % msg)
- if support.verbose:
- sys.stdout.write(" client: closing connection.\n")
- if wrapped:
- conn.write(b"over\n")
+ " client: read %r from server, ending TLS...\n"
+ % msg)
+ s = conn.unwrap()
+ wrapped = False
else:
- s.send(b"over\n")
- if wrapped:
- conn.close()
- else:
- s.close()
-
- def test_socketserver(self):
- """Using socketserver to create and manage SSL connections."""
- server = make_https_server(self, certfile=CERTFILE)
- # try to connect
- if support.verbose:
- sys.stdout.write('\n')
- with open(CERTFILE, 'rb') as f:
- d1 = f.read()
- d2 = ''
- # now fetch the same data from the HTTPS server
- url = 'https://localhost:%d/%s' % (
- server.port, os.path.split(CERTFILE)[1])
- context = ssl.create_default_context(cafile=CERTFILE)
- f = urllib.request.urlopen(url, context=context)
- try:
- dlen = f.info().get("content-length")
- if dlen and (int(dlen) > 0):
- d2 = f.read(int(dlen))
if support.verbose:
sys.stdout.write(
- " client: read %d bytes from remote server '%s'\n"
- % (len(d2), server))
- finally:
- f.close()
- self.assertEqual(d1, d2)
-
- def test_asyncore_server(self):
- """Check the example asyncore integration."""
+ " client: read %r from server\n" % msg)
if support.verbose:
- sys.stdout.write("\n")
+ sys.stdout.write(" client: closing connection.\n")
+ if wrapped:
+ conn.write(b"over\n")
+ else:
+ s.send(b"over\n")
+ if wrapped:
+ conn.close()
+ else:
+ s.close()
- indata = b"FOO\n"
- server = AsyncoreEchoServer(CERTFILE)
- with server:
- s = test_wrap_socket(socket.socket())
- s.connect(('127.0.0.1', server.port))
+ def test_socketserver(self):
+ """Using socketserver to create and manage SSL connections."""
+ server = make_https_server(self, certfile=CERTFILE)
+ # try to connect
+ if support.verbose:
+ sys.stdout.write('\n')
+ with open(CERTFILE, 'rb') as f:
+ d1 = f.read()
+ d2 = ''
+ # now fetch the same data from the HTTPS server
+ url = 'https://localhost:%d/%s' % (
+ server.port, os.path.split(CERTFILE)[1])
+ context = ssl.create_default_context(cafile=CERTFILE)
+ f = urllib.request.urlopen(url, context=context)
+ try:
+ dlen = f.info().get("content-length")
+ if dlen and (int(dlen) > 0):
+ d2 = f.read(int(dlen))
if support.verbose:
sys.stdout.write(
- " client: sending %r...\n" % indata)
- s.write(indata)
- outdata = s.read()
- if support.verbose:
- sys.stdout.write(" client: read %r\n" % outdata)
- if outdata != indata.lower():
- self.fail(
- "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
- % (outdata[:20], len(outdata),
- indata[:20].lower(), len(indata)))
- s.write(b"over\n")
- if support.verbose:
- sys.stdout.write(" client: closing connection.\n")
- s.close()
- if support.verbose:
- sys.stdout.write(" client: connection closed.\n")
+ " client: read %d bytes from remote server '%s'\n"
+ % (len(d2), server))
+ finally:
+ f.close()
+ self.assertEqual(d1, d2)
- def test_recv_send(self):
- """Test recv(), send() and friends."""
+ def test_asyncore_server(self):
+ """Check the example asyncore integration."""
+ if support.verbose:
+ sys.stdout.write("\n")
+
+ indata = b"FOO\n"
+ server = AsyncoreEchoServer(CERTFILE)
+ with server:
+ s = test_wrap_socket(socket.socket())
+ s.connect(('127.0.0.1', server.port))
if support.verbose:
- sys.stdout.write("\n")
+ sys.stdout.write(
+ " client: sending %r...\n" % indata)
+ s.write(indata)
+ outdata = s.read()
+ if support.verbose:
+ sys.stdout.write(" client: read %r\n" % outdata)
+ if outdata != indata.lower():
+ self.fail(
+ "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
+ % (outdata[:20], len(outdata),
+ indata[:20].lower(), len(indata)))
+ s.write(b"over\n")
+ if support.verbose:
+ sys.stdout.write(" client: closing connection.\n")
+ s.close()
+ if support.verbose:
+ sys.stdout.write(" client: connection closed.\n")
- server = ThreadedEchoServer(CERTFILE,
- certreqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1,
- cacerts=CERTFILE,
- chatty=True,
- connectionchatty=False)
- with server:
- s = test_wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
- s.connect((HOST, server.port))
- # helper methods for standardising recv* method signatures
- def _recv_into():
- b = bytearray(b"\0"*100)
- count = s.recv_into(b)
- return b[:count]
-
- def _recvfrom_into():
- b = bytearray(b"\0"*100)
- count, addr = s.recvfrom_into(b)
- return b[:count]
-
- # (name, method, expect success?, *args, return value func)
- send_methods = [
- ('send', s.send, True, [], len),
- ('sendto', s.sendto, False, ["some.address"], len),
- ('sendall', s.sendall, True, [], lambda x: None),
- ]
- # (name, method, whether to expect success, *args)
- recv_methods = [
- ('recv', s.recv, True, []),
- ('recvfrom', s.recvfrom, False, ["some.address"]),
- ('recv_into', _recv_into, True, []),
- ('recvfrom_into', _recvfrom_into, False, []),
- ]
- data_prefix = "PREFIX_"
-
- for (meth_name, send_meth, expect_success, args,
- ret_val_meth) in send_methods:
- indata = (data_prefix + meth_name).encode('ascii')
- try:
- ret = send_meth(indata, *args)
- msg = "sending with {}".format(meth_name)
- self.assertEqual(ret, ret_val_meth(indata), msg=msg)
- outdata = s.read()
- if outdata != indata.lower():
- self.fail(
- "While sending with <<{name:s}>> bad data "
- "<<{outdata:r}>> ({nout:d}) received; "
- "expected <<{indata:r}>> ({nin:d})\n".format(
- name=meth_name, outdata=outdata[:20],
- nout=len(outdata),
- indata=indata[:20], nin=len(indata)
- )
- )
- except ValueError as e:
- if expect_success:
- self.fail(
- "Failed to send with method <<{name:s}>>; "
- "expected to succeed.\n".format(name=meth_name)
+ def test_recv_send(self):
+ """Test recv(), send() and friends."""
+ if support.verbose:
+ sys.stdout.write("\n")
+
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ cacerts=CERTFILE,
+ chatty=True,
+ connectionchatty=False)
+ with server:
+ s = test_wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ # helper methods for standardising recv* method signatures
+ def _recv_into():
+ b = bytearray(b"\0"*100)
+ count = s.recv_into(b)
+ return b[:count]
+
+ def _recvfrom_into():
+ b = bytearray(b"\0"*100)
+ count, addr = s.recvfrom_into(b)
+ return b[:count]
+
+ # (name, method, expect success?, *args, return value func)
+ send_methods = [
+ ('send', s.send, True, [], len),
+ ('sendto', s.sendto, False, ["some.address"], len),
+ ('sendall', s.sendall, True, [], lambda x: None),
+ ]
+ # (name, method, whether to expect success, *args)
+ recv_methods = [
+ ('recv', s.recv, True, []),
+ ('recvfrom', s.recvfrom, False, ["some.address"]),
+ ('recv_into', _recv_into, True, []),
+ ('recvfrom_into', _recvfrom_into, False, []),
+ ]
+ data_prefix = "PREFIX_"
+
+ for (meth_name, send_meth, expect_success, args,
+ ret_val_meth) in send_methods:
+ indata = (data_prefix + meth_name).encode('ascii')
+ try:
+ ret = send_meth(indata, *args)
+ msg = "sending with {}".format(meth_name)
+ self.assertEqual(ret, ret_val_meth(indata), msg=msg)
+ outdata = s.read()
+ if outdata != indata.lower():
+ self.fail(
+ "While sending with <<{name:s}>> bad data "
+ "<<{outdata:r}>> ({nout:d}) received; "
+ "expected <<{indata:r}>> ({nin:d})\n".format(
+ name=meth_name, outdata=outdata[:20],
+ nout=len(outdata),
+ indata=indata[:20], nin=len(indata)
)
- if not str(e).startswith(meth_name):
- self.fail(
- "Method <<{name:s}>> failed with unexpected "
- "exception message: {exp:s}\n".format(
- name=meth_name, exp=e
- )
+ )
+ except ValueError as e:
+ if expect_success:
+ self.fail(
+ "Failed to send with method <<{name:s}>>; "
+ "expected to succeed.\n".format(name=meth_name)
+ )
+ if not str(e).startswith(meth_name):
+ self.fail(
+ "Method <<{name:s}>> failed with unexpected "
+ "exception message: {exp:s}\n".format(
+ name=meth_name, exp=e
)
+ )
- for meth_name, recv_meth, expect_success, args in recv_methods:
- indata = (data_prefix + meth_name).encode('ascii')
- try:
- s.send(indata)
- outdata = recv_meth(*args)
- if outdata != indata.lower():
- self.fail(
- "While receiving with <<{name:s}>> bad data "
- "<<{outdata:r}>> ({nout:d}) received; "
- "expected <<{indata:r}>> ({nin:d})\n".format(
- name=meth_name, outdata=outdata[:20],
- nout=len(outdata),
- indata=indata[:20], nin=len(indata)
- )
- )
- except ValueError as e:
- if expect_success:
- self.fail(
- "Failed to receive with method <<{name:s}>>; "
- "expected to succeed.\n".format(name=meth_name)
+ for meth_name, recv_meth, expect_success, args in recv_methods:
+ indata = (data_prefix + meth_name).encode('ascii')
+ try:
+ s.send(indata)
+ outdata = recv_meth(*args)
+ if outdata != indata.lower():
+ self.fail(
+ "While receiving with <<{name:s}>> bad data "
+ "<<{outdata:r}>> ({nout:d}) received; "
+ "expected <<{indata:r}>> ({nin:d})\n".format(
+ name=meth_name, outdata=outdata[:20],
+ nout=len(outdata),
+ indata=indata[:20], nin=len(indata)
)
- if not str(e).startswith(meth_name):
- self.fail(
- "Method <<{name:s}>> failed with unexpected "
- "exception message: {exp:s}\n".format(
- name=meth_name, exp=e
- )
+ )
+ except ValueError as e:
+ if expect_success:
+ self.fail(
+ "Failed to receive with method <<{name:s}>>; "
+ "expected to succeed.\n".format(name=meth_name)
+ )
+ if not str(e).startswith(meth_name):
+ self.fail(
+ "Method <<{name:s}>> failed with unexpected "
+ "exception message: {exp:s}\n".format(
+ name=meth_name, exp=e
)
- # consume data
- s.read()
+ )
+ # consume data
+ s.read()
- # read(-1, buffer) is supported, even though read(-1) is not
- data = b"data"
- s.send(data)
- buffer = bytearray(len(data))
- self.assertEqual(s.read(-1, buffer), len(data))
- self.assertEqual(buffer, data)
+ # read(-1, buffer) is supported, even though read(-1) is not
+ data = b"data"
+ s.send(data)
+ buffer = bytearray(len(data))
+ self.assertEqual(s.read(-1, buffer), len(data))
+ self.assertEqual(buffer, data)
- # Make sure sendmsg et al are disallowed to avoid
- # inadvertent disclosure of data and/or corruption
- # of the encrypted data stream
- self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
- self.assertRaises(NotImplementedError, s.recvmsg, 100)
- self.assertRaises(NotImplementedError,
- s.recvmsg_into, bytearray(100))
+ # Make sure sendmsg et al are disallowed to avoid
+ # inadvertent disclosure of data and/or corruption
+ # of the encrypted data stream
+ self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
+ self.assertRaises(NotImplementedError, s.recvmsg, 100)
+ self.assertRaises(NotImplementedError,
+ s.recvmsg_into, bytearray(100))
- s.write(b"over\n")
+ s.write(b"over\n")
- self.assertRaises(ValueError, s.recv, -1)
- self.assertRaises(ValueError, s.read, -1)
+ self.assertRaises(ValueError, s.recv, -1)
+ self.assertRaises(ValueError, s.read, -1)
- s.close()
+ s.close()
- def test_recv_zero(self):
- server = ThreadedEchoServer(CERTFILE)
- server.__enter__()
- self.addCleanup(server.__exit__, None, None)
- s = socket.create_connection((HOST, server.port))
- self.addCleanup(s.close)
- s = test_wrap_socket(s, suppress_ragged_eofs=False)
- self.addCleanup(s.close)
+ def test_recv_zero(self):
+ server = ThreadedEchoServer(CERTFILE)
+ server.__enter__()
+ self.addCleanup(server.__exit__, None, None)
+ s = socket.create_connection((HOST, server.port))
+ self.addCleanup(s.close)
+ s = test_wrap_socket(s, suppress_ragged_eofs=False)
+ self.addCleanup(s.close)
- # recv/read(0) should return no data
- s.send(b"data")
- self.assertEqual(s.recv(0), b"")
- self.assertEqual(s.read(0), b"")
- self.assertEqual(s.read(), b"data")
+ # recv/read(0) should return no data
+ s.send(b"data")
+ self.assertEqual(s.recv(0), b"")
+ self.assertEqual(s.read(0), b"")
+ self.assertEqual(s.read(), b"data")
+
+ # Should not block if the other end sends no data
+ s.setblocking(False)
+ self.assertEqual(s.recv(0), b"")
+ self.assertEqual(s.recv_into(bytearray()), 0)
- # Should not block if the other end sends no data
+ def test_nonblocking_send(self):
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ cacerts=CERTFILE,
+ chatty=True,
+ connectionchatty=False)
+ with server:
+ s = test_wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
s.setblocking(False)
- self.assertEqual(s.recv(0), b"")
- self.assertEqual(s.recv_into(bytearray()), 0)
-
- def test_nonblocking_send(self):
- server = ThreadedEchoServer(CERTFILE,
- certreqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1,
- cacerts=CERTFILE,
- chatty=True,
- connectionchatty=False)
- with server:
- s = test_wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
- s.connect((HOST, server.port))
- s.setblocking(False)
-
- # If we keep sending data, at some point the buffers
- # will be full and the call will block
- buf = bytearray(8192)
- def fill_buffer():
- while True:
- s.send(buf)
- self.assertRaises((ssl.SSLWantWriteError,
- ssl.SSLWantReadError), fill_buffer)
-
- # Now read all the output and discard it
- s.setblocking(True)
- s.close()
- def test_handshake_timeout(self):
- # Issue #5103: SSL handshake must respect the socket timeout
- server = socket.socket(socket.AF_INET)
- host = "127.0.0.1"
- port = support.bind_port(server)
- started = threading.Event()
- finish = False
-
- def serve():
- server.listen()
- started.set()
- conns = []
- while not finish:
- r, w, e = select.select([server], [], [], 0.1)
- if server in r:
- # Let the socket hang around rather than having
- # it closed by garbage collection.
- conns.append(server.accept()[0])
- for sock in conns:
- sock.close()
-
- t = threading.Thread(target=serve)
- t.start()
- started.wait()
+ # If we keep sending data, at some point the buffers
+ # will be full and the call will block
+ buf = bytearray(8192)
+ def fill_buffer():
+ while True:
+ s.send(buf)
+ self.assertRaises((ssl.SSLWantWriteError,
+ ssl.SSLWantReadError), fill_buffer)
+
+ # Now read all the output and discard it
+ s.setblocking(True)
+ s.close()
+
+ def test_handshake_timeout(self):
+ # Issue #5103: SSL handshake must respect the socket timeout
+ server = socket.socket(socket.AF_INET)
+ host = "127.0.0.1"
+ port = support.bind_port(server)
+ started = threading.Event()
+ finish = False
+
+ def serve():
+ server.listen()
+ started.set()
+ conns = []
+ while not finish:
+ r, w, e = select.select([server], [], [], 0.1)
+ if server in r:
+ # Let the socket hang around rather than having
+ # it closed by garbage collection.
+ conns.append(server.accept()[0])
+ for sock in conns:
+ sock.close()
+
+ t = threading.Thread(target=serve)
+ t.start()
+ started.wait()
+ try:
try:
- try:
- c = socket.socket(socket.AF_INET)
- c.settimeout(0.2)
- c.connect((host, port))
- # Will attempt handshake and time out
- self.assertRaisesRegex(socket.timeout, "timed out",
- test_wrap_socket, c)
- finally:
- c.close()
- try:
- c = socket.socket(socket.AF_INET)
- c = test_wrap_socket(c)
- c.settimeout(0.2)
- # Will attempt handshake and time out
- self.assertRaisesRegex(socket.timeout, "timed out",
- c.connect, (host, port))
- finally:
- c.close()
+ c = socket.socket(socket.AF_INET)
+ c.settimeout(0.2)
+ c.connect((host, port))
+ # Will attempt handshake and time out
+ self.assertRaisesRegex(socket.timeout, "timed out",
+ test_wrap_socket, c)
finally:
- finish = True
- t.join()
- server.close()
-
- def test_server_accept(self):
- # Issue #16357: accept() on a SSLSocket created through
- # SSLContext.wrap_socket().
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(CERTFILE)
- context.load_cert_chain(CERTFILE)
- server = socket.socket(socket.AF_INET)
- host = "127.0.0.1"
- port = support.bind_port(server)
- server = context.wrap_socket(server, server_side=True)
- self.assertTrue(server.server_side)
-
- evt = threading.Event()
- remote = None
- peer = None
- def serve():
- nonlocal remote, peer
- server.listen()
- # Block on the accept and wait on the connection to close.
- evt.set()
- remote, peer = server.accept()
- remote.recv(1)
-
- t = threading.Thread(target=serve)
- t.start()
- # Client wait until server setup and perform a connect.
- evt.wait()
- client = context.wrap_socket(socket.socket())
- client.connect((host, port))
- client_addr = client.getsockname()
- client.close()
+ c.close()
+ try:
+ c = socket.socket(socket.AF_INET)
+ c = test_wrap_socket(c)
+ c.settimeout(0.2)
+ # Will attempt handshake and time out
+ self.assertRaisesRegex(socket.timeout, "timed out",
+ c.connect, (host, port))
+ finally:
+ c.close()
+ finally:
+ finish = True
t.join()
- remote.close()
server.close()
- # Sanity checks.
- self.assertIsInstance(remote, ssl.SSLSocket)
- self.assertEqual(peer, client_addr)
-
- def test_getpeercert_enotconn(self):
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- with context.wrap_socket(socket.socket()) as sock:
- with self.assertRaises(OSError) as cm:
- sock.getpeercert()
- self.assertEqual(cm.exception.errno, errno.ENOTCONN)
-
- def test_do_handshake_enotconn(self):
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- with context.wrap_socket(socket.socket()) as sock:
- with self.assertRaises(OSError) as cm:
- sock.do_handshake()
- self.assertEqual(cm.exception.errno, errno.ENOTCONN)
-
- def test_default_ciphers(self):
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- try:
- # Force a set of weak ciphers on our client context
- context.set_ciphers("DES")
- except ssl.SSLError:
- self.skipTest("no DES cipher available")
- with ThreadedEchoServer(CERTFILE,
- ssl_version=ssl.PROTOCOL_SSLv23,
- chatty=False) as server:
- with context.wrap_socket(socket.socket()) as s:
- with self.assertRaises(OSError):
- s.connect((HOST, server.port))
- self.assertIn("no shared cipher", server.conn_errors[0])
-
- def test_version_basic(self):
- """
- Basic tests for SSLSocket.version().
- More tests are done in the test_protocol_*() methods.
- """
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- with ThreadedEchoServer(CERTFILE,
- ssl_version=ssl.PROTOCOL_TLSv1,
- chatty=False) as server:
- with context.wrap_socket(socket.socket()) as s:
- self.assertIs(s.version(), None)
- s.connect((HOST, server.port))
- self.assertEqual(s.version(), 'TLSv1')
- self.assertIs(s.version(), None)
- @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
- def test_default_ecdh_curve(self):
- # Issue #21015: elliptic curve-based Diffie Hellman key exchange
- # should be enabled by default on SSL contexts.
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.load_cert_chain(CERTFILE)
- # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
- # explicitly using the 'ECCdraft' cipher alias. Otherwise,
- # our default cipher list should prefer ECDH-based ciphers
- # automatically.
- if ssl.OPENSSL_VERSION_INFO < (1, 0, 0):
- context.set_ciphers("ECCdraft:ECDH")
- with ThreadedEchoServer(context=context) as server:
- with context.wrap_socket(socket.socket()) as s:
+ def test_server_accept(self):
+ # Issue #16357: accept() on a SSLSocket created through
+ # SSLContext.wrap_socket().
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+ server = socket.socket(socket.AF_INET)
+ host = "127.0.0.1"
+ port = support.bind_port(server)
+ server = context.wrap_socket(server, server_side=True)
+ self.assertTrue(server.server_side)
+
+ evt = threading.Event()
+ remote = None
+ peer = None
+ def serve():
+ nonlocal remote, peer
+ server.listen()
+ # Block on the accept and wait on the connection to close.
+ evt.set()
+ remote, peer = server.accept()
+ remote.recv(1)
+
+ t = threading.Thread(target=serve)
+ t.start()
+ # Client wait until server setup and perform a connect.
+ evt.wait()
+ client = context.wrap_socket(socket.socket())
+ client.connect((host, port))
+ client_addr = client.getsockname()
+ client.close()
+ t.join()
+ remote.close()
+ server.close()
+ # Sanity checks.
+ self.assertIsInstance(remote, ssl.SSLSocket)
+ self.assertEqual(peer, client_addr)
+
+ def test_getpeercert_enotconn(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ with context.wrap_socket(socket.socket()) as sock:
+ with self.assertRaises(OSError) as cm:
+ sock.getpeercert()
+ self.assertEqual(cm.exception.errno, errno.ENOTCONN)
+
+ def test_do_handshake_enotconn(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ with context.wrap_socket(socket.socket()) as sock:
+ with self.assertRaises(OSError) as cm:
+ sock.do_handshake()
+ self.assertEqual(cm.exception.errno, errno.ENOTCONN)
+
+ def test_default_ciphers(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ try:
+ # Force a set of weak ciphers on our client context
+ context.set_ciphers("DES")
+ except ssl.SSLError:
+ self.skipTest("no DES cipher available")
+ with ThreadedEchoServer(CERTFILE,
+ ssl_version=ssl.PROTOCOL_SSLv23,
+ chatty=False) as server:
+ with context.wrap_socket(socket.socket()) as s:
+ with self.assertRaises(OSError):
s.connect((HOST, server.port))
- self.assertIn("ECDH", s.cipher()[0])
+ self.assertIn("no shared cipher", server.conn_errors[0])
- @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
- "'tls-unique' channel binding not available")
- def test_tls_unique_channel_binding(self):
- """Test tls-unique channel binding."""
- if support.verbose:
- sys.stdout.write("\n")
-
- server = ThreadedEchoServer(CERTFILE,
- certreqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1,
- cacerts=CERTFILE,
- chatty=True,
- connectionchatty=False)
- with server:
- s = test_wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
+ def test_version_basic(self):
+ """
+ Basic tests for SSLSocket.version().
+ More tests are done in the test_protocol_*() methods.
+ """
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ with ThreadedEchoServer(CERTFILE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ chatty=False) as server:
+ with context.wrap_socket(socket.socket()) as s:
+ self.assertIs(s.version(), None)
s.connect((HOST, server.port))
- # get the data
- cb_data = s.get_channel_binding("tls-unique")
- if support.verbose:
- sys.stdout.write(" got channel binding data: {0!r}\n"
- .format(cb_data))
-
- # check if it is sane
- self.assertIsNotNone(cb_data)
- self.assertEqual(len(cb_data), 12) # True for TLSv1
-
- # and compare with the peers version
- s.write(b"CB tls-unique\n")
- peer_data_repr = s.read().strip()
- self.assertEqual(peer_data_repr,
- repr(cb_data).encode("us-ascii"))
- s.close()
-
- # now, again
- s = test_wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
+ self.assertEqual(s.version(), 'TLSv1')
+ self.assertIs(s.version(), None)
+
+ @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
+ def test_default_ecdh_curve(self):
+ # Issue #21015: elliptic curve-based Diffie Hellman key exchange
+ # should be enabled by default on SSL contexts.
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.load_cert_chain(CERTFILE)
+ # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
+ # explicitly using the 'ECCdraft' cipher alias. Otherwise,
+ # our default cipher list should prefer ECDH-based ciphers
+ # automatically.
+ if ssl.OPENSSL_VERSION_INFO < (1, 0, 0):
+ context.set_ciphers("ECCdraft:ECDH")
+ with ThreadedEchoServer(context=context) as server:
+ with context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
- new_cb_data = s.get_channel_binding("tls-unique")
- if support.verbose:
- sys.stdout.write(" got another channel binding data: {0!r}\n"
- .format(new_cb_data))
- # is it really unique
- self.assertNotEqual(cb_data, new_cb_data)
- self.assertIsNotNone(cb_data)
- self.assertEqual(len(cb_data), 12) # True for TLSv1
- s.write(b"CB tls-unique\n")
- peer_data_repr = s.read().strip()
- self.assertEqual(peer_data_repr,
- repr(new_cb_data).encode("us-ascii"))
- s.close()
+ self.assertIn("ECDH", s.cipher()[0])
- def test_compression(self):
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
- if support.verbose:
- sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
- self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
-
- @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
- "ssl.OP_NO_COMPRESSION needed for this test")
- def test_compression_disabled(self):
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- context.options |= ssl.OP_NO_COMPRESSION
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
- self.assertIs(stats['compression'], None)
-
- def test_dh_params(self):
- # Check we can get a connection with ephemeral Diffie-Hellman
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- context.load_dh_params(DHFILE)
- context.set_ciphers("kEDH")
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
- cipher = stats["cipher"][0]
- parts = cipher.split("-")
- if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
- self.fail("Non-DH cipher: " + cipher[0])
-
- def test_selected_alpn_protocol(self):
- # selected_alpn_protocol() is None unless ALPN is used.
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
- self.assertIs(stats['client_alpn_protocol'], None)
+ @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
+ "'tls-unique' channel binding not available")
+ def test_tls_unique_channel_binding(self):
+ """Test tls-unique channel binding."""
+ if support.verbose:
+ sys.stdout.write("\n")
- @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
- def test_selected_alpn_protocol_if_server_uses_alpn(self):
- # selected_alpn_protocol() is None unless ALPN is used by the client.
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.load_verify_locations(CERTFILE)
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ cacerts=CERTFILE,
+ chatty=True,
+ connectionchatty=False)
+ with server:
+ s = test_wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ # get the data
+ cb_data = s.get_channel_binding("tls-unique")
+ if support.verbose:
+ sys.stdout.write(" got channel binding data: {0!r}\n"
+ .format(cb_data))
+
+ # check if it is sane
+ self.assertIsNotNone(cb_data)
+ self.assertEqual(len(cb_data), 12) # True for TLSv1
+
+ # and compare with the peers version
+ s.write(b"CB tls-unique\n")
+ peer_data_repr = s.read().strip()
+ self.assertEqual(peer_data_repr,
+ repr(cb_data).encode("us-ascii"))
+ s.close()
+
+ # now, again
+ s = test_wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ new_cb_data = s.get_channel_binding("tls-unique")
+ if support.verbose:
+ sys.stdout.write(" got another channel binding data: {0!r}\n"
+ .format(new_cb_data))
+ # is it really unique
+ self.assertNotEqual(cb_data, new_cb_data)
+ self.assertIsNotNone(cb_data)
+ self.assertEqual(len(cb_data), 12) # True for TLSv1
+ s.write(b"CB tls-unique\n")
+ peer_data_repr = s.read().strip()
+ self.assertEqual(peer_data_repr,
+ repr(new_cb_data).encode("us-ascii"))
+ s.close()
+
+ def test_compression(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ if support.verbose:
+ sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
+ self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
+
+ @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
+ "ssl.OP_NO_COMPRESSION needed for this test")
+ def test_compression_disabled(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ context.options |= ssl.OP_NO_COMPRESSION
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['compression'], None)
+
+ def test_dh_params(self):
+ # Check we can get a connection with ephemeral Diffie-Hellman
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ context.load_dh_params(DHFILE)
+ context.set_ciphers("kEDH")
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ cipher = stats["cipher"][0]
+ parts = cipher.split("-")
+ if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
+ self.fail("Non-DH cipher: " + cipher[0])
+
+ def test_selected_alpn_protocol(self):
+ # selected_alpn_protocol() is None unless ALPN is used.
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_alpn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
+ def test_selected_alpn_protocol_if_server_uses_alpn(self):
+ # selected_alpn_protocol() is None unless ALPN is used by the client.
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.load_verify_locations(CERTFILE)
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(CERTFILE)
+ server_context.set_alpn_protocols(['foo', 'bar'])
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_alpn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
+ def test_alpn_protocols(self):
+ server_protocols = ['foo', 'bar', 'milkshake']
+ protocol_tests = [
+ (['foo', 'bar'], 'foo'),
+ (['bar', 'foo'], 'foo'),
+ (['milkshake'], 'milkshake'),
+ (['http/3.0', 'http/4.0'], None)
+ ]
+ for client_protocols, expected in protocol_tests:
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
server_context.load_cert_chain(CERTFILE)
- server_context.set_alpn_protocols(['foo', 'bar'])
- stats = server_params_test(client_context, server_context,
- chatty=True, connectionchatty=True)
- self.assertIs(stats['client_alpn_protocol'], None)
-
- @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
- def test_alpn_protocols(self):
- server_protocols = ['foo', 'bar', 'milkshake']
- protocol_tests = [
- (['foo', 'bar'], 'foo'),
- (['bar', 'foo'], 'foo'),
- (['milkshake'], 'milkshake'),
- (['http/3.0', 'http/4.0'], None)
- ]
- for client_protocols, expected in protocol_tests:
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
- server_context.load_cert_chain(CERTFILE)
- server_context.set_alpn_protocols(server_protocols)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
- client_context.load_cert_chain(CERTFILE)
- client_context.set_alpn_protocols(client_protocols)
+ server_context.set_alpn_protocols(server_protocols)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
+ client_context.load_cert_chain(CERTFILE)
+ client_context.set_alpn_protocols(client_protocols)
- try:
- stats = server_params_test(client_context,
- server_context,
- chatty=True,
- connectionchatty=True)
- except ssl.SSLError as e:
- stats = e
-
- if (expected is None and IS_OPENSSL_1_1
- and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)):
- # OpenSSL 1.1.0 to 1.1.0e raises handshake error
- self.assertIsInstance(stats, ssl.SSLError)
- else:
- msg = "failed trying %s (s) and %s (c).\n" \
- "was expecting %s, but got %%s from the %%s" \
- % (str(server_protocols), str(client_protocols),
- str(expected))
- client_result = stats['client_alpn_protocol']
- self.assertEqual(client_result, expected,
- msg % (client_result, "client"))
- server_result = stats['server_alpn_protocols'][-1] \
- if len(stats['server_alpn_protocols']) else 'nothing'
- self.assertEqual(server_result, expected,
- msg % (server_result, "server"))
-
- def test_selected_npn_protocol(self):
- # selected_npn_protocol() is None unless NPN is used
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
- self.assertIs(stats['client_npn_protocol'], None)
-
- @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
- def test_npn_protocols(self):
- server_protocols = ['http/1.1', 'spdy/2']
- protocol_tests = [
- (['http/1.1', 'spdy/2'], 'http/1.1'),
- (['spdy/2', 'http/1.1'], 'http/1.1'),
- (['spdy/2', 'test'], 'spdy/2'),
- (['abc', 'def'], 'abc')
- ]
- for client_protocols, expected in protocol_tests:
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(CERTFILE)
- server_context.set_npn_protocols(server_protocols)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.load_cert_chain(CERTFILE)
- client_context.set_npn_protocols(client_protocols)
- stats = server_params_test(client_context, server_context,
- chatty=True, connectionchatty=True)
+ try:
+ stats = server_params_test(client_context,
+ server_context,
+ chatty=True,
+ connectionchatty=True)
+ except ssl.SSLError as e:
+ stats = e
+ if (expected is None and IS_OPENSSL_1_1
+ and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)):
+ # OpenSSL 1.1.0 to 1.1.0e raises handshake error
+ self.assertIsInstance(stats, ssl.SSLError)
+ else:
msg = "failed trying %s (s) and %s (c).\n" \
- "was expecting %s, but got %%s from the %%s" \
- % (str(server_protocols), str(client_protocols),
- str(expected))
- client_result = stats['client_npn_protocol']
- self.assertEqual(client_result, expected, msg % (client_result, "client"))
- server_result = stats['server_npn_protocols'][-1] \
- if len(stats['server_npn_protocols']) else 'nothing'
- self.assertEqual(server_result, expected, msg % (server_result, "server"))
-
- def sni_contexts(self):
+ "was expecting %s, but got %%s from the %%s" \
+ % (str(server_protocols), str(client_protocols),
+ str(expected))
+ client_result = stats['client_alpn_protocol']
+ self.assertEqual(client_result, expected,
+ msg % (client_result, "client"))
+ server_result = stats['server_alpn_protocols'][-1] \
+ if len(stats['server_alpn_protocols']) else 'nothing'
+ self.assertEqual(server_result, expected,
+ msg % (server_result, "server"))
+
+ def test_selected_npn_protocol(self):
+ # selected_npn_protocol() is None unless NPN is used
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_npn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
+ def test_npn_protocols(self):
+ server_protocols = ['http/1.1', 'spdy/2']
+ protocol_tests = [
+ (['http/1.1', 'spdy/2'], 'http/1.1'),
+ (['spdy/2', 'http/1.1'], 'http/1.1'),
+ (['spdy/2', 'test'], 'spdy/2'),
+ (['abc', 'def'], 'abc')
+ ]
+ for client_protocols, expected in protocol_tests:
server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(SIGNED_CERTFILE)
- other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- other_context.load_cert_chain(SIGNED_CERTFILE2)
+ server_context.load_cert_chain(CERTFILE)
+ server_context.set_npn_protocols(server_protocols)
client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.verify_mode = ssl.CERT_REQUIRED
- client_context.load_verify_locations(SIGNING_CA)
- return server_context, other_context, client_context
+ client_context.load_cert_chain(CERTFILE)
+ client_context.set_npn_protocols(client_protocols)
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True)
- def check_common_name(self, stats, name):
- cert = stats['peercert']
- self.assertIn((('commonName', name),), cert['subject'])
+ msg = "failed trying %s (s) and %s (c).\n" \
+ "was expecting %s, but got %%s from the %%s" \
+ % (str(server_protocols), str(client_protocols),
+ str(expected))
+ client_result = stats['client_npn_protocol']
+ self.assertEqual(client_result, expected, msg % (client_result, "client"))
+ server_result = stats['server_npn_protocols'][-1] \
+ if len(stats['server_npn_protocols']) else 'nothing'
+ self.assertEqual(server_result, expected, msg % (server_result, "server"))
+
+ def sni_contexts(self):
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+ other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ other_context.load_cert_chain(SIGNED_CERTFILE2)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.load_verify_locations(SIGNING_CA)
+ return server_context, other_context, client_context
+
+ def check_common_name(self, stats, name):
+ cert = stats['peercert']
+ self.assertIn((('commonName', name),), cert['subject'])
+
+ @needs_sni
+ def test_sni_callback(self):
+ calls = []
+ server_context, other_context, client_context = self.sni_contexts()
+
+ def servername_cb(ssl_sock, server_name, initial_context):
+ calls.append((server_name, initial_context))
+ if server_name is not None:
+ ssl_sock.context = other_context
+ server_context.set_servername_callback(servername_cb)
+
+ stats = server_params_test(client_context, server_context,
+ chatty=True,
+ sni_name='supermessage')
+ # The hostname was fetched properly, and the certificate was
+ # changed for the connection.
+ self.assertEqual(calls, [("supermessage", server_context)])
+ # CERTFILE4 was selected
+ self.check_common_name(stats, 'fakehostname')
+
+ calls = []
+ # The callback is called with server_name=None
+ stats = server_params_test(client_context, server_context,
+ chatty=True,
+ sni_name=None)
+ self.assertEqual(calls, [(None, server_context)])
+ self.check_common_name(stats, 'localhost')
+
+ # Check disabling the callback
+ calls = []
+ server_context.set_servername_callback(None)
+
+ stats = server_params_test(client_context, server_context,
+ chatty=True,
+ sni_name='notfunny')
+ # Certificate didn't change
+ self.check_common_name(stats, 'localhost')
+ self.assertEqual(calls, [])
- @needs_sni
- def test_sni_callback(self):
- calls = []
- server_context, other_context, client_context = self.sni_contexts()
+ @needs_sni
+ def test_sni_callback_alert(self):
+ # Returning a TLS alert is reflected to the connecting client
+ server_context, other_context, client_context = self.sni_contexts()
- def servername_cb(ssl_sock, server_name, initial_context):
- calls.append((server_name, initial_context))
- if server_name is not None:
- ssl_sock.context = other_context
- server_context.set_servername_callback(servername_cb)
+ def cb_returning_alert(ssl_sock, server_name, initial_context):
+ return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
+ server_context.set_servername_callback(cb_returning_alert)
+ with self.assertRaises(ssl.SSLError) as cm:
stats = server_params_test(client_context, server_context,
- chatty=True,
+ chatty=False,
sni_name='supermessage')
- # The hostname was fetched properly, and the certificate was
- # changed for the connection.
- self.assertEqual(calls, [("supermessage", server_context)])
- # CERTFILE4 was selected
- self.check_common_name(stats, 'fakehostname')
-
- calls = []
- # The callback is called with server_name=None
+ self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
+
+ @needs_sni
+ def test_sni_callback_raising(self):
+ # Raising fails the connection with a TLS handshake failure alert.
+ server_context, other_context, client_context = self.sni_contexts()
+
+ def cb_raising(ssl_sock, server_name, initial_context):
+ 1/0
+ server_context.set_servername_callback(cb_raising)
+
+ with self.assertRaises(ssl.SSLError) as cm, \
+ support.captured_stderr() as stderr:
stats = server_params_test(client_context, server_context,
- chatty=True,
- sni_name=None)
- self.assertEqual(calls, [(None, server_context)])
- self.check_common_name(stats, 'localhost')
+ chatty=False,
+ sni_name='supermessage')
+ self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE')
+ self.assertIn("ZeroDivisionError", stderr.getvalue())
+
+ @needs_sni
+ def test_sni_callback_wrong_return_type(self):
+ # Returning the wrong return type terminates the TLS connection
+ # with an internal error alert.
+ server_context, other_context, client_context = self.sni_contexts()
- # Check disabling the callback
- calls = []
- server_context.set_servername_callback(None)
+ def cb_wrong_return_type(ssl_sock, server_name, initial_context):
+ return "foo"
+ server_context.set_servername_callback(cb_wrong_return_type)
+ with self.assertRaises(ssl.SSLError) as cm, \
+ support.captured_stderr() as stderr:
stats = server_params_test(client_context, server_context,
- chatty=True,
- sni_name='notfunny')
- # Certificate didn't change
- self.check_common_name(stats, 'localhost')
- self.assertEqual(calls, [])
-
- @needs_sni
- def test_sni_callback_alert(self):
- # Returning a TLS alert is reflected to the connecting client
- server_context, other_context, client_context = self.sni_contexts()
-
- def cb_returning_alert(ssl_sock, server_name, initial_context):
- return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
- server_context.set_servername_callback(cb_returning_alert)
-
- with self.assertRaises(ssl.SSLError) as cm:
- stats = server_params_test(client_context, server_context,
- chatty=False,
- sni_name='supermessage')
- self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
-
- @needs_sni
- def test_sni_callback_raising(self):
- # Raising fails the connection with a TLS handshake failure alert.
- server_context, other_context, client_context = self.sni_contexts()
-
- def cb_raising(ssl_sock, server_name, initial_context):
- 1/0
- server_context.set_servername_callback(cb_raising)
-
- with self.assertRaises(ssl.SSLError) as cm, \
- support.captured_stderr() as stderr:
- stats = server_params_test(client_context, server_context,
- chatty=False,
- sni_name='supermessage')
- self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE')
- self.assertIn("ZeroDivisionError", stderr.getvalue())
-
- @needs_sni
- def test_sni_callback_wrong_return_type(self):
- # Returning the wrong return type terminates the TLS connection
- # with an internal error alert.
- server_context, other_context, client_context = self.sni_contexts()
-
- def cb_wrong_return_type(ssl_sock, server_name, initial_context):
- return "foo"
- server_context.set_servername_callback(cb_wrong_return_type)
-
- with self.assertRaises(ssl.SSLError) as cm, \
- support.captured_stderr() as stderr:
- stats = server_params_test(client_context, server_context,
- chatty=False,
- sni_name='supermessage')
- self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
- self.assertIn("TypeError", stderr.getvalue())
-
- def test_shared_ciphers(self):
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(SIGNED_CERTFILE)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.verify_mode = ssl.CERT_REQUIRED
- client_context.load_verify_locations(SIGNING_CA)
- if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
- client_context.set_ciphers("AES128:AES256")
- server_context.set_ciphers("AES256")
- alg1 = "AES256"
- alg2 = "AES-256"
- else:
- client_context.set_ciphers("AES:3DES")
- server_context.set_ciphers("3DES")
- alg1 = "3DES"
- alg2 = "DES-CBC3"
-
- stats = server_params_test(client_context, server_context)
- ciphers = stats['server_shared_ciphers'][0]
- self.assertGreater(len(ciphers), 0)
- for name, tls_version, bits in ciphers:
- if not alg1 in name.split("-") and alg2 not in name:
- self.fail(name)
-
- def test_read_write_after_close_raises_valuerror(self):
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(CERTFILE)
- context.load_cert_chain(CERTFILE)
- server = ThreadedEchoServer(context=context, chatty=False)
-
- with server:
- s = context.wrap_socket(socket.socket())
+ chatty=False,
+ sni_name='supermessage')
+ self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
+ self.assertIn("TypeError", stderr.getvalue())
+
+ def test_shared_ciphers(self):
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.load_verify_locations(SIGNING_CA)
+ if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
+ client_context.set_ciphers("AES128:AES256")
+ server_context.set_ciphers("AES256")
+ alg1 = "AES256"
+ alg2 = "AES-256"
+ else:
+ client_context.set_ciphers("AES:3DES")
+ server_context.set_ciphers("3DES")
+ alg1 = "3DES"
+ alg2 = "DES-CBC3"
+
+ stats = server_params_test(client_context, server_context)
+ ciphers = stats['server_shared_ciphers'][0]
+ self.assertGreater(len(ciphers), 0)
+ for name, tls_version, bits in ciphers:
+ if not alg1 in name.split("-") and alg2 not in name:
+ self.fail(name)
+
+ def test_read_write_after_close_raises_valuerror(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+ server = ThreadedEchoServer(context=context, chatty=False)
+
+ with server:
+ s = context.wrap_socket(socket.socket())
+ s.connect((HOST, server.port))
+ s.close()
+
+ self.assertRaises(ValueError, s.read, 1024)
+ self.assertRaises(ValueError, s.write, b'hello')
+
+ def test_sendfile(self):
+ TEST_DATA = b"x" * 512
+ with open(support.TESTFN, 'wb') as f:
+ f.write(TEST_DATA)
+ self.addCleanup(support.unlink, support.TESTFN)
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+ server = ThreadedEchoServer(context=context, chatty=False)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
- s.close()
+ with open(support.TESTFN, 'rb') as file:
+ s.sendfile(file)
+ self.assertEqual(s.recv(1024), TEST_DATA)
+
+ def test_session(self):
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.load_verify_locations(SIGNING_CA)
+
+ # first connection without session
+ stats = server_params_test(client_context, server_context)
+ session = stats['session']
+ self.assertTrue(session.id)
+ self.assertGreater(session.time, 0)
+ self.assertGreater(session.timeout, 0)
+ self.assertTrue(session.has_ticket)
+ if ssl.OPENSSL_VERSION_INFO > (1, 0, 1):
+ self.assertGreater(session.ticket_lifetime_hint, 0)
+ self.assertFalse(stats['session_reused'])
+ sess_stat = server_context.session_stats()
+ self.assertEqual(sess_stat['accept'], 1)
+ self.assertEqual(sess_stat['hits'], 0)
+
+ # reuse session
+ stats = server_params_test(client_context, server_context, session=session)
+ sess_stat = server_context.session_stats()
+ self.assertEqual(sess_stat['accept'], 2)
+ self.assertEqual(sess_stat['hits'], 1)
+ self.assertTrue(stats['session_reused'])
+ session2 = stats['session']
+ self.assertEqual(session2.id, session.id)
+ self.assertEqual(session2, session)
+ self.assertIsNot(session2, session)
+ self.assertGreaterEqual(session2.time, session.time)
+ self.assertGreaterEqual(session2.timeout, session.timeout)
+
+ # another one without session
+ stats = server_params_test(client_context, server_context)
+ self.assertFalse(stats['session_reused'])
+ session3 = stats['session']
+ self.assertNotEqual(session3.id, session.id)
+ self.assertNotEqual(session3, session)
+ sess_stat = server_context.session_stats()
+ self.assertEqual(sess_stat['accept'], 3)
+ self.assertEqual(sess_stat['hits'], 1)
+
+ # reuse session again
+ stats = server_params_test(client_context, server_context, session=session)
+ self.assertTrue(stats['session_reused'])
+ session4 = stats['session']
+ self.assertEqual(session4.id, session.id)
+ self.assertEqual(session4, session)
+ self.assertGreaterEqual(session4.time, session.time)
+ self.assertGreaterEqual(session4.timeout, session.timeout)
+ sess_stat = server_context.session_stats()
+ self.assertEqual(sess_stat['accept'], 4)
+ self.assertEqual(sess_stat['hits'], 2)
+
+ def test_session_handling(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+
+ context2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context2.verify_mode = ssl.CERT_REQUIRED
+ context2.load_verify_locations(CERTFILE)
+ context2.load_cert_chain(CERTFILE)
+
+ server = ThreadedEchoServer(context=context, chatty=False)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
+ # session is None before handshake
+ self.assertEqual(s.session, None)
+ self.assertEqual(s.session_reused, None)
+ s.connect((HOST, server.port))
+ session = s.session
+ self.assertTrue(session)
+ with self.assertRaises(TypeError) as e:
+ s.session = object
+ self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
- self.assertRaises(ValueError, s.read, 1024)
- self.assertRaises(ValueError, s.write, b'hello')
-
- def test_sendfile(self):
- TEST_DATA = b"x" * 512
- with open(support.TESTFN, 'wb') as f:
- f.write(TEST_DATA)
- self.addCleanup(support.unlink, support.TESTFN)
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(CERTFILE)
- context.load_cert_chain(CERTFILE)
- server = ThreadedEchoServer(context=context, chatty=False)
- with server:
- with context.wrap_socket(socket.socket()) as s:
- s.connect((HOST, server.port))
- with open(support.TESTFN, 'rb') as file:
- s.sendfile(file)
- self.assertEqual(s.recv(1024), TEST_DATA)
+ with context.wrap_socket(socket.socket()) as s:
+ s.connect((HOST, server.port))
+ # cannot set session after handshake
+ with self.assertRaises(ValueError) as e:
+ s.session = session
+ self.assertEqual(str(e.exception),
+ 'Cannot set session after handshake.')
- def test_session(self):
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(SIGNED_CERTFILE)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.verify_mode = ssl.CERT_REQUIRED
- client_context.load_verify_locations(SIGNING_CA)
-
- # first connection without session
- stats = server_params_test(client_context, server_context)
- session = stats['session']
- self.assertTrue(session.id)
- self.assertGreater(session.time, 0)
- self.assertGreater(session.timeout, 0)
- self.assertTrue(session.has_ticket)
- if ssl.OPENSSL_VERSION_INFO > (1, 0, 1):
- self.assertGreater(session.ticket_lifetime_hint, 0)
- self.assertFalse(stats['session_reused'])
- sess_stat = server_context.session_stats()
- self.assertEqual(sess_stat['accept'], 1)
- self.assertEqual(sess_stat['hits'], 0)
-
- # reuse session
- stats = server_params_test(client_context, server_context, session=session)
- sess_stat = server_context.session_stats()
- self.assertEqual(sess_stat['accept'], 2)
- self.assertEqual(sess_stat['hits'], 1)
- self.assertTrue(stats['session_reused'])
- session2 = stats['session']
- self.assertEqual(session2.id, session.id)
- self.assertEqual(session2, session)
- self.assertIsNot(session2, session)
- self.assertGreaterEqual(session2.time, session.time)
- self.assertGreaterEqual(session2.timeout, session.timeout)
-
- # another one without session
- stats = server_params_test(client_context, server_context)
- self.assertFalse(stats['session_reused'])
- session3 = stats['session']
- self.assertNotEqual(session3.id, session.id)
- self.assertNotEqual(session3, session)
- sess_stat = server_context.session_stats()
- self.assertEqual(sess_stat['accept'], 3)
- self.assertEqual(sess_stat['hits'], 1)
-
- # reuse session again
- stats = server_params_test(client_context, server_context, session=session)
- self.assertTrue(stats['session_reused'])
- session4 = stats['session']
- self.assertEqual(session4.id, session.id)
- self.assertEqual(session4, session)
- self.assertGreaterEqual(session4.time, session.time)
- self.assertGreaterEqual(session4.timeout, session.timeout)
- sess_stat = server_context.session_stats()
- self.assertEqual(sess_stat['accept'], 4)
- self.assertEqual(sess_stat['hits'], 2)
-
- def test_session_handling(self):
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(CERTFILE)
- context.load_cert_chain(CERTFILE)
-
- context2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context2.verify_mode = ssl.CERT_REQUIRED
- context2.load_verify_locations(CERTFILE)
- context2.load_cert_chain(CERTFILE)
-
- server = ThreadedEchoServer(context=context, chatty=False)
- with server:
- with context.wrap_socket(socket.socket()) as s:
- # session is None before handshake
- self.assertEqual(s.session, None)
- self.assertEqual(s.session_reused, None)
- s.connect((HOST, server.port))
- session = s.session
- self.assertTrue(session)
- with self.assertRaises(TypeError) as e:
- s.session = object
- self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
+ with context.wrap_socket(socket.socket()) as s:
+ # can set session before handshake and before the
+ # connection was established
+ s.session = session
+ s.connect((HOST, server.port))
+ self.assertEqual(s.session.id, session.id)
+ self.assertEqual(s.session, session)
+ self.assertEqual(s.session_reused, True)
- with context.wrap_socket(socket.socket()) as s:
- s.connect((HOST, server.port))
- # cannot set session after handshake
- with self.assertRaises(ValueError) as e:
- s.session = session
- self.assertEqual(str(e.exception),
- 'Cannot set session after handshake.')
-
- with context.wrap_socket(socket.socket()) as s:
- # can set session before handshake and before the
- # connection was established
+ with context2.wrap_socket(socket.socket()) as s:
+ # cannot re-use session with a different SSLContext
+ with self.assertRaises(ValueError) as e:
s.session = session
s.connect((HOST, server.port))
- self.assertEqual(s.session.id, session.id)
- self.assertEqual(s.session, session)
- self.assertEqual(s.session_reused, True)
-
- with context2.wrap_socket(socket.socket()) as s:
- # cannot re-use session with a different SSLContext
- with self.assertRaises(ValueError) as e:
- s.session = session
- s.connect((HOST, server.port))
- self.assertEqual(str(e.exception),
- 'Session refers to a different SSLContext.')
+ self.assertEqual(str(e.exception),
+ 'Session refers to a different SSLContext.')
def test_main(verbose=False):
@@ -3610,22 +3603,17 @@ def test_main(verbose=False):
tests = [
ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
- SimpleBackgroundTests,
+ SimpleBackgroundTests, ThreadedTests,
]
if support.is_resource_enabled('network'):
tests.append(NetworkedTests)
- if _have_threads:
- thread_info = support.threading_setup()
- if thread_info:
- tests.append(ThreadedTests)
-
+ thread_info = support.threading_setup()
try:
support.run_unittest(*tests)
finally:
- if _have_threads:
- support.threading_cleanup(*thread_info)
+ support.threading_cleanup(*thread_info)
if __name__ == "__main__":
test_main()