summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/ssl.py1
-rw-r--r--Lib/test/test_ssl.py156
2 files changed, 100 insertions, 57 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index 24d3771..585105d 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -63,6 +63,7 @@ from _ssl import _SSLContext, SSLError
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from _ssl import (PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23,
PROTOCOL_TLSv1)
+from _ssl import OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1
from _ssl import RAND_status, RAND_egd, RAND_add
from _ssl import (
SSL_ERROR_ZERO_RETURN,
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index c9dc47a..c464440 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -57,6 +57,14 @@ def handle_error(prefix):
if support.verbose:
sys.stdout.write(prefix + exc_format)
+def can_clear_options():
+ # 0.9.8m or higher
+ return ssl.OPENSSL_VERSION_INFO >= (0, 9, 8, 13, 15)
+
+def no_sslv2_implies_sslv3_hello():
+ # 0.9.7h or higher
+ return ssl.OPENSSL_VERSION_INFO >= (0, 9, 7, 8, 15)
+
class BasicSocketTests(unittest.TestCase):
@@ -189,6 +197,26 @@ class ContextTests(unittest.TestCase):
with self.assertRaisesRegexp(ssl.SSLError, "No cipher can be selected"):
ctx.set_ciphers("^$:,;?*'dorothyx")
+ def test_options(self):
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ # OP_ALL is the default value
+ self.assertEqual(ssl.OP_ALL, ctx.options)
+ ctx.options |= ssl.OP_NO_SSLv2
+ self.assertEqual(ssl.OP_ALL | ssl.OP_NO_SSLv2,
+ ctx.options)
+ ctx.options |= ssl.OP_NO_SSLv3
+ self.assertEqual(ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3,
+ ctx.options)
+ if can_clear_options():
+ ctx.options = (ctx.options & ~ssl.OP_NO_SSLv2) | ssl.OP_NO_TLSv1
+ self.assertEqual(ssl.OP_ALL | ssl.OP_NO_TLSv1 | ssl.OP_NO_SSLv3,
+ ctx.options)
+ ctx.options = 0
+ self.assertEqual(0, ctx.options)
+ else:
+ with self.assertRaises(ValueError):
+ ctx.options = 0
+
def test_verify(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
# Default value
@@ -445,12 +473,8 @@ else:
def wrap_conn(self):
try:
- self.sslconn = ssl.wrap_socket(self.sock, server_side=True,
- certfile=self.server.certificate,
- ssl_version=self.server.protocol,
- ca_certs=self.server.cacerts,
- cert_reqs=self.server.certreqs,
- ciphers=self.server.ciphers)
+ self.sslconn = self.server.context.wrap_socket(
+ self.sock, server_side=True)
except ssl.SSLError:
# XXX Various errors can have happened here, for example
# a mismatching protocol version, an invalid certificate,
@@ -462,7 +486,7 @@ else:
self.close()
return False
else:
- if self.server.certreqs == ssl.CERT_REQUIRED:
+ if self.server.context.verify_mode == ssl.CERT_REQUIRED:
cert = self.sslconn.getpeercert()
if support.verbose and self.server.chatty:
sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
@@ -542,19 +566,24 @@ else:
# harness, we want to stop the server
self.server.stop()
- def __init__(self, certificate, ssl_version=None,
+ def __init__(self, certificate=None, ssl_version=None,
certreqs=None, cacerts=None,
chatty=True, connectionchatty=False, starttls_server=False,
- ciphers=None):
- if ssl_version is None:
- ssl_version = ssl.PROTOCOL_TLSv1
- if certreqs is None:
- certreqs = ssl.CERT_NONE
- self.certificate = certificate
- self.protocol = ssl_version
- self.certreqs = certreqs
- self.cacerts = cacerts
- self.ciphers = ciphers
+ ciphers=None, context=None):
+ if context:
+ self.context = context
+ else:
+ self.context = ssl.SSLContext(ssl_version
+ if ssl_version is not None
+ else ssl.PROTOCOL_TLSv1)
+ self.context.verify_mode = (certreqs if certreqs is not None
+ else ssl.CERT_NONE)
+ if cacerts:
+ self.context.load_verify_locations(cacerts)
+ if certificate:
+ self.context.load_cert_chain(certificate)
+ if ciphers:
+ self.context.set_ciphers(ciphers)
self.chatty = chatty
self.connectionchatty = connectionchatty
self.starttls_server = starttls_server
@@ -820,18 +849,13 @@ else:
server.stop()
server.join()
- def server_params_test(certfile, protocol, certreqs, cacertsfile,
- client_certfile, client_protocol=None, indata=b"FOO\n",
- ciphers=None, chatty=True, connectionchatty=False):
+ def server_params_test(client_context, server_context, indata=b"FOO\n",
+ chatty=True, connectionchatty=False):
"""
Launch a server, connect a client to it and try various reads
and writes.
"""
- server = ThreadedEchoServer(certfile,
- certreqs=certreqs,
- ssl_version=protocol,
- cacerts=cacertsfile,
- ciphers=ciphers,
+ server = ThreadedEchoServer(context=server_context,
chatty=chatty,
connectionchatty=False)
flag = threading.Event()
@@ -839,15 +863,8 @@ else:
# wait for it to start
flag.wait()
# try to connect
- if client_protocol is None:
- client_protocol = protocol
try:
- s = ssl.wrap_socket(socket.socket(),
- certfile=client_certfile,
- ca_certs=cacertsfile,
- ciphers=ciphers,
- cert_reqs=certreqs,
- ssl_version=client_protocol)
+ s = client_context.wrap_socket(socket.socket())
s.connect((HOST, server.port))
for arg in [indata, bytearray(indata), memoryview(indata)]:
if connectionchatty:
@@ -873,10 +890,8 @@ else:
server.stop()
server.join()
- def try_protocol_combo(server_protocol,
- client_protocol,
- expect_success,
- certsreqs=None):
+ def try_protocol_combo(server_protocol, client_protocol, expect_success,
+ certsreqs=None, server_options=0, client_options=0):
if certsreqs is None:
certsreqs = ssl.CERT_NONE
certtype = {
@@ -890,14 +905,21 @@ else:
(ssl.get_protocol_name(client_protocol),
ssl.get_protocol_name(server_protocol),
certtype))
- try:
+ client_context = ssl.SSLContext(client_protocol)
+ client_context.options = ssl.OP_ALL | client_options
+ server_context = ssl.SSLContext(server_protocol)
+ server_context.options = ssl.OP_ALL | server_options
+ for ctx in (client_context, server_context):
+ ctx.verify_mode = certsreqs
# NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client
# will send an SSLv3 hello (rather than SSLv2) starting from
# OpenSSL 1.0.0 (see issue #8322).
- server_params_test(CERTFILE, server_protocol, certsreqs,
- CERTFILE, CERTFILE, client_protocol,
- ciphers="ALL", chatty=False,
- connectionchatty=False)
+ ctx.set_ciphers("ALL")
+ ctx.load_cert_chain(CERTFILE)
+ ctx.load_verify_locations(CERTFILE)
+ try:
+ server_params_test(client_context, server_context,
+ chatty=False, connectionchatty=False)
# Protocol mismatch can result in either an SSLError, or a
# "Connection reset by peer" error.
except ssl.SSLError:
@@ -920,30 +942,27 @@ else:
"""Basic test of an SSL client connecting to a server"""
if support.verbose:
sys.stdout.write("\n")
- server_params_test(CERTFILE, ssl.PROTOCOL_TLSv1, ssl.CERT_NONE,
- CERTFILE, CERTFILE, ssl.PROTOCOL_TLSv1,
- chatty=True, connectionchatty=True)
+ for protocol in PROTOCOLS:
+ context = ssl.SSLContext(protocol)
+ context.load_cert_chain(CERTFILE)
+ server_params_test(context, context,
+ chatty=True, connectionchatty=True)
def test_getpeercert(self):
if support.verbose:
sys.stdout.write("\n")
- s2 = socket.socket()
- server = ThreadedEchoServer(CERTFILE,
- certreqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_SSLv23,
- cacerts=CERTFILE,
- chatty=False)
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+ server = ThreadedEchoServer(context=context, chatty=False)
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
try:
- s = ssl.wrap_socket(socket.socket(),
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_REQUIRED,
- ssl_version=ssl.PROTOCOL_SSLv23)
+ s = context.wrap_socket(socket.socket())
s.connect((HOST, server.port))
cert = s.getpeercert()
self.assertTrue(cert, "Can't get peer certificate.")
@@ -1031,6 +1050,15 @@ else:
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
+ # SSLv23 client with specific SSL options
+ if no_sslv2_implies_sslv3_hello():
+ # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_SSLv2)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True,
+ client_options=ssl.OP_NO_SSLv3)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True,
+ client_options=ssl.OP_NO_TLSv1)
def test_protocol_sslv23(self):
"""Connecting to an SSLv23 server with various client options"""
@@ -1056,6 +1084,16 @@ else:
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
+ # Server with specific SSL options
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False,
+ server_options=ssl.OP_NO_SSLv3)
+ # Will choose TLSv1
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True,
+ server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False,
+ server_options=ssl.OP_NO_TLSv1)
+
+
def test_protocol_sslv3(self):
"""Connecting to an SSLv3 server with various client options"""
if support.verbose:
@@ -1066,6 +1104,10 @@ else:
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
+ if no_sslv2_implies_sslv3_hello():
+ # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, True,
+ client_options=ssl.OP_NO_SSLv2)
def test_protocol_tlsv1(self):
"""Connecting to a TLSv1 server with various client options"""