summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r--Lib/test/test_ssl.py477
1 files changed, 440 insertions, 37 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 779b622..ea619fd 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -86,6 +86,12 @@ def have_verify_flags():
# 0.9.8 or higher
return ssl.OPENSSL_VERSION_INFO >= (0, 9, 8, 0, 15)
+def utc_offset(): #NOTE: ignore issues like #1647654
+ # local time = utc time + utc offset
+ if time.daylight and time.localtime().tm_isdst > 0:
+ return -time.altzone # seconds
+ return -time.timezone
+
def asn1time(cert_time):
# Some versions of OpenSSL ignore seconds, see #18207
# 0.9.8.i
@@ -134,6 +140,14 @@ class BasicSocketTests(unittest.TestCase):
self.assertIn(ssl.HAS_SNI, {True, False})
self.assertIn(ssl.HAS_ECDH, {True, False})
+ def test_str_for_enums(self):
+ # Make sure that the PROTOCOL_* constants have enum-like string
+ # reprs.
+ proto = ssl.PROTOCOL_SSLv23
+ self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_SSLv23')
+ ctx = ssl.SSLContext(proto)
+ self.assertIs(ctx.protocol, proto)
+
def test_random(self):
v = ssl.RAND_status()
if support.verbose:
@@ -298,10 +312,10 @@ class BasicSocketTests(unittest.TestCase):
# Version string as returned by {Open,Libre}SSL, the format might change
if "LibreSSL" in s:
self.assertTrue(s.startswith("LibreSSL {:d}.{:d}".format(major, minor)),
- (s, t))
+ (s, t, hex(n)))
else:
self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
- (s, t))
+ (s, t, hex(n)))
@support.cpython_only
def test_refcycle(self):
@@ -369,6 +383,8 @@ class BasicSocketTests(unittest.TestCase):
self.assertRaises(ssl.CertificateError,
ssl.match_hostname, cert, hostname)
+ # -- Hostname matching --
+
cert = {'subject': ((('commonName', 'example.com'),),)}
ok(cert, 'example.com')
ok(cert, 'ExAmple.cOm')
@@ -454,6 +470,28 @@ class BasicSocketTests(unittest.TestCase):
# Only commonName is considered
fail(cert, 'California')
+ # -- IPv4 matching --
+ cert = {'subject': ((('commonName', 'example.com'),),),
+ 'subjectAltName': (('DNS', 'example.com'),
+ ('IP Address', '10.11.12.13'),
+ ('IP Address', '14.15.16.17'))}
+ ok(cert, '10.11.12.13')
+ ok(cert, '14.15.16.17')
+ fail(cert, '14.15.16.18')
+ fail(cert, 'example.net')
+
+ # -- IPv6 matching --
+ cert = {'subject': ((('commonName', 'example.com'),),),
+ 'subjectAltName': (('DNS', 'example.com'),
+ ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'),
+ ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))}
+ ok(cert, '2001::cafe')
+ ok(cert, '2003::baba')
+ fail(cert, '2003::bebe')
+ fail(cert, 'example.net')
+
+ # -- Miscellaneous --
+
# Neither commonName nor subjectAltName
cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
'subject': ((('countryName', 'US'),),
@@ -505,9 +543,14 @@ class BasicSocketTests(unittest.TestCase):
def test_unknown_channel_binding(self):
# should raise ValueError for unknown type
s = socket.socket(socket.AF_INET)
- with ssl.wrap_socket(s) as ss:
+ s.bind(('127.0.0.1', 0))
+ s.listen()
+ c = socket.socket(socket.AF_INET)
+ c.connect(s.getsockname())
+ with ssl.wrap_socket(c, do_handshake_on_connect=False) as ss:
with self.assertRaises(ValueError):
ss.get_channel_binding("unknown-type")
+ s.close()
@unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
"'tls-unique' channel binding not available")
@@ -648,6 +691,71 @@ class BasicSocketTests(unittest.TestCase):
ctx.wrap_socket(s)
self.assertEqual(str(cx.exception), "only stream sockets are supported")
+ def cert_time_ok(self, timestring, timestamp):
+ self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp)
+
+ def cert_time_fail(self, timestring):
+ with self.assertRaises(ValueError):
+ ssl.cert_time_to_seconds(timestring)
+
+ @unittest.skipUnless(utc_offset(),
+ 'local time needs to be different from UTC')
+ def test_cert_time_to_seconds_timezone(self):
+ # Issue #19940: ssl.cert_time_to_seconds() returns wrong
+ # results if local timezone is not UTC
+ self.cert_time_ok("May 9 00:00:00 2007 GMT", 1178668800.0)
+ self.cert_time_ok("Jan 5 09:34:43 2018 GMT", 1515144883.0)
+
+ def test_cert_time_to_seconds(self):
+ timestring = "Jan 5 09:34:43 2018 GMT"
+ ts = 1515144883.0
+ self.cert_time_ok(timestring, ts)
+ # accept keyword parameter, assert its name
+ self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts)
+ # accept both %e and %d (space or zero generated by strftime)
+ self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts)
+ # case-insensitive
+ self.cert_time_ok("JaN 5 09:34:43 2018 GmT", ts)
+ self.cert_time_fail("Jan 5 09:34 2018 GMT") # no seconds
+ self.cert_time_fail("Jan 5 09:34:43 2018") # no GMT
+ self.cert_time_fail("Jan 5 09:34:43 2018 UTC") # not GMT timezone
+ self.cert_time_fail("Jan 35 09:34:43 2018 GMT") # invalid day
+ self.cert_time_fail("Jon 5 09:34:43 2018 GMT") # invalid month
+ self.cert_time_fail("Jan 5 24:00:00 2018 GMT") # invalid hour
+ self.cert_time_fail("Jan 5 09:60:43 2018 GMT") # invalid minute
+
+ newyear_ts = 1230768000.0
+ # leap seconds
+ self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts)
+ # same timestamp
+ self.cert_time_ok("Jan 1 00:00:00 2009 GMT", newyear_ts)
+
+ self.cert_time_ok("Jan 5 09:34:59 2018 GMT", 1515144899)
+ # allow 60th second (even if it is not a leap second)
+ self.cert_time_ok("Jan 5 09:34:60 2018 GMT", 1515144900)
+ # allow 2nd leap second for compatibility with time.strptime()
+ self.cert_time_ok("Jan 5 09:34:61 2018 GMT", 1515144901)
+ self.cert_time_fail("Jan 5 09:34:62 2018 GMT") # invalid seconds
+
+ # no special treatement for the special value:
+ # 99991231235959Z (rfc 5280)
+ self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
+
+ @support.run_with_locale('LC_ALL', '')
+ def test_cert_time_to_seconds_locale(self):
+ # `cert_time_to_seconds()` should be locale independent
+
+ def local_february_name():
+ return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0))
+
+ if local_february_name().lower() == 'feb':
+ self.skipTest("locale-specific month name needs to be "
+ "different from C locale")
+
+ # locale-independent
+ self.cert_time_ok("Feb 9 00:00:00 2007 GMT", 1170979200.0)
+ self.cert_time_fail(local_february_name() + " 9 00:00:00 2007 GMT")
+
class ContextTests(unittest.TestCase):
@@ -1157,7 +1265,7 @@ class SSLErrorTests(unittest.TestCase):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
- s.listen(5)
+ s.listen()
c = socket.socket()
c.connect(s.getsockname())
c.setblocking(False)
@@ -1170,6 +1278,69 @@ class SSLErrorTests(unittest.TestCase):
self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)
+class MemoryBIOTests(unittest.TestCase):
+
+ def test_read_write(self):
+ bio = ssl.MemoryBIO()
+ bio.write(b'foo')
+ self.assertEqual(bio.read(), b'foo')
+ self.assertEqual(bio.read(), b'')
+ bio.write(b'foo')
+ bio.write(b'bar')
+ self.assertEqual(bio.read(), b'foobar')
+ self.assertEqual(bio.read(), b'')
+ bio.write(b'baz')
+ self.assertEqual(bio.read(2), b'ba')
+ self.assertEqual(bio.read(1), b'z')
+ self.assertEqual(bio.read(1), b'')
+
+ def test_eof(self):
+ bio = ssl.MemoryBIO()
+ self.assertFalse(bio.eof)
+ self.assertEqual(bio.read(), b'')
+ self.assertFalse(bio.eof)
+ bio.write(b'foo')
+ self.assertFalse(bio.eof)
+ bio.write_eof()
+ self.assertFalse(bio.eof)
+ self.assertEqual(bio.read(2), b'fo')
+ self.assertFalse(bio.eof)
+ self.assertEqual(bio.read(1), b'o')
+ self.assertTrue(bio.eof)
+ self.assertEqual(bio.read(), b'')
+ self.assertTrue(bio.eof)
+
+ def test_pending(self):
+ bio = ssl.MemoryBIO()
+ self.assertEqual(bio.pending, 0)
+ bio.write(b'foo')
+ self.assertEqual(bio.pending, 3)
+ for i in range(3):
+ bio.read(1)
+ self.assertEqual(bio.pending, 3-i-1)
+ for i in range(3):
+ bio.write(b'x')
+ self.assertEqual(bio.pending, i+1)
+ bio.read()
+ self.assertEqual(bio.pending, 0)
+
+ def test_buffer_types(self):
+ bio = ssl.MemoryBIO()
+ bio.write(b'foo')
+ self.assertEqual(bio.read(), b'foo')
+ bio.write(bytearray(b'bar'))
+ self.assertEqual(bio.read(), b'bar')
+ bio.write(memoryview(b'baz'))
+ self.assertEqual(bio.read(), b'baz')
+
+ def test_error_types(self):
+ bio = ssl.MemoryBIO()
+ self.assertRaises(TypeError, bio.write, 'foo')
+ self.assertRaises(TypeError, bio.write, None)
+ self.assertRaises(TypeError, bio.write, True)
+ self.assertRaises(TypeError, bio.write, 1)
+
+
class NetworkedTests(unittest.TestCase):
def test_connect(self):
@@ -1397,14 +1568,12 @@ class NetworkedTests(unittest.TestCase):
def test_get_server_certificate(self):
def _test_get_server_certificate(host, port, cert=None):
with support.transient_internet(host):
- pem = ssl.get_server_certificate((host, port),
- ssl.PROTOCOL_SSLv23)
+ pem = ssl.get_server_certificate((host, port))
if not pem:
self.fail("No server certificate on %s:%s!" % (host, port))
try:
pem = ssl.get_server_certificate((host, port),
- ssl.PROTOCOL_SSLv23,
ca_certs=CERTFILE)
except ssl.SSLError as x:
#should fail
@@ -1414,7 +1583,6 @@ class NetworkedTests(unittest.TestCase):
self.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
pem = ssl.get_server_certificate((host, port),
- ssl.PROTOCOL_SSLv23,
ca_certs=cert)
if not pem:
self.fail("No server certificate on %s:%s!" % (host, port))
@@ -1500,6 +1668,93 @@ class NetworkedTests(unittest.TestCase):
self.assertIs(ss.context, ctx2)
self.assertIs(ss._sslobj.context, ctx2)
+
+class NetworkedBIOTests(unittest.TestCase):
+
+ def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs):
+ # A simple IO loop. Call func(*args) depending on the error we get
+ # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs.
+ timeout = kwargs.get('timeout', 10)
+ count = 0
+ while True:
+ errno = None
+ count += 1
+ try:
+ ret = func(*args)
+ except ssl.SSLError as e:
+ # Note that we get a spurious -1/SSL_ERROR_SYSCALL for
+ # non-blocking IO. The SSL_shutdown manpage hints at this.
+ # It *should* be safe to just ignore SYS_ERROR_SYSCALL because
+ # with a Memory BIO there's no syscalls (for IO at least).
+ if e.errno not in (ssl.SSL_ERROR_WANT_READ,
+ ssl.SSL_ERROR_WANT_WRITE,
+ ssl.SSL_ERROR_SYSCALL):
+ raise
+ errno = e.errno
+ # Get any data from the outgoing BIO irrespective of any error, and
+ # send it to the socket.
+ buf = outgoing.read()
+ sock.sendall(buf)
+ # If there's no error, we're done. For WANT_READ, we need to get
+ # data from the socket and put it in the incoming BIO.
+ if errno is None:
+ break
+ elif errno == ssl.SSL_ERROR_WANT_READ:
+ buf = sock.recv(32768)
+ if buf:
+ incoming.write(buf)
+ else:
+ incoming.write_eof()
+ if support.verbose:
+ sys.stdout.write("Needed %d calls to complete %s().\n"
+ % (count, func.__name__))
+ return ret
+
+ def test_handshake(self):
+ with support.transient_internet("svn.python.org"):
+ sock = socket.socket(socket.AF_INET)
+ sock.connect(("svn.python.org", 443))
+ incoming = ssl.MemoryBIO()
+ outgoing = ssl.MemoryBIO()
+ ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ ctx.verify_mode = ssl.CERT_REQUIRED
+ ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT)
+ ctx.check_hostname = True
+ sslobj = ctx.wrap_bio(incoming, outgoing, False, 'svn.python.org')
+ self.assertIs(sslobj._sslobj.owner, sslobj)
+ self.assertIsNone(sslobj.cipher())
+ self.assertIsNone(sslobj.shared_ciphers())
+ self.assertRaises(ValueError, sslobj.getpeercert)
+ if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
+ self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
+ self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
+ self.assertTrue(sslobj.cipher())
+ self.assertIsNone(sslobj.shared_ciphers())
+ self.assertTrue(sslobj.getpeercert())
+ if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
+ self.assertTrue(sslobj.get_channel_binding('tls-unique'))
+ self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
+ self.assertRaises(ssl.SSLError, sslobj.write, b'foo')
+ sock.close()
+
+ def test_read_write_data(self):
+ with support.transient_internet("svn.python.org"):
+ sock = socket.socket(socket.AF_INET)
+ sock.connect(("svn.python.org", 443))
+ incoming = ssl.MemoryBIO()
+ outgoing = ssl.MemoryBIO()
+ ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ ctx.verify_mode = ssl.CERT_NONE
+ sslobj = ctx.wrap_bio(incoming, outgoing, False)
+ self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
+ req = b'GET / HTTP/1.0\r\n\r\n'
+ self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
+ buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
+ self.assertEqual(buf[:5], b'HTTP/')
+ self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
+ sock.close()
+
+
try:
import threading
except ImportError:
@@ -1531,7 +1786,8 @@ else:
try:
self.sslconn = self.server.context.wrap_socket(
self.sock, server_side=True)
- self.server.selected_protocols.append(self.sslconn.selected_npn_protocol())
+ self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
+ self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
except (ssl.SSLError, ConnectionResetError) as e:
# We treat ConnectionResetError as though it were an
# SSLError - OpenSSL on Ubuntu abruptly closes the
@@ -1548,6 +1804,7 @@ else:
self.close()
return False
else:
+ self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
if self.server.context.verify_mode == ssl.CERT_REQUIRED:
cert = self.sslconn.getpeercert()
if support.verbose and self.server.chatty:
@@ -1638,7 +1895,8 @@ else:
def __init__(self, certificate=None, ssl_version=None,
certreqs=None, cacerts=None,
chatty=True, connectionchatty=False, starttls_server=False,
- npn_protocols=None, ciphers=None, context=None):
+ npn_protocols=None, alpn_protocols=None,
+ ciphers=None, context=None):
if context:
self.context = context
else:
@@ -1653,6 +1911,8 @@ else:
self.context.load_cert_chain(certificate)
if npn_protocols:
self.context.set_npn_protocols(npn_protocols)
+ if alpn_protocols:
+ self.context.set_alpn_protocols(alpn_protocols)
if ciphers:
self.context.set_ciphers(ciphers)
self.chatty = chatty
@@ -1662,7 +1922,9 @@ else:
self.port = support.bind_port(self.sock)
self.flag = None
self.active = False
- self.selected_protocols = []
+ self.selected_npn_protocols = []
+ self.selected_alpn_protocols = []
+ self.shared_ciphers = []
self.conn_errors = []
threading.Thread.__init__(self)
self.daemon = True
@@ -1682,7 +1944,7 @@ else:
def run(self):
self.sock.settimeout(0.05)
- self.sock.listen(5)
+ self.sock.listen()
self.active = True
if self.flag:
# signal an event
@@ -1888,14 +2150,25 @@ else:
'compression': s.compression(),
'cipher': s.cipher(),
'peercert': s.getpeercert(),
- 'client_npn_protocol': s.selected_npn_protocol()
+ 'client_alpn_protocol': s.selected_alpn_protocol(),
+ 'client_npn_protocol': s.selected_npn_protocol(),
+ 'version': s.version(),
})
s.close()
- stats['server_npn_protocols'] = server.selected_protocols
+ stats['server_alpn_protocols'] = server.selected_alpn_protocols
+ stats['server_npn_protocols'] = server.selected_npn_protocols
+ stats['server_shared_ciphers'] = server.shared_ciphers
return stats
def try_protocol_combo(server_protocol, client_protocol, expect_success,
certsreqs=None, server_options=0, client_options=0):
+ """
+ Try to SSL-connect using *client_protocol* to *server_protocol*.
+ If *expect_success* is true, assert that the connection succeeds,
+ if it's false, assert that the connection fails.
+ Also, if *expect_success* is a string, assert that it is the protocol
+ version actually used by the connection.
+ """
if certsreqs is None:
certsreqs = ssl.CERT_NONE
certtype = {
@@ -1925,8 +2198,8 @@ else:
ctx.load_cert_chain(CERTFILE)
ctx.load_verify_locations(CERTFILE)
try:
- server_params_test(client_context, server_context,
- chatty=False, connectionchatty=False)
+ stats = server_params_test(client_context, server_context,
+ chatty=False, connectionchatty=False)
# Protocol mismatch can result in either an SSLError, or a
# "Connection reset by peer" error.
except ssl.SSLError:
@@ -1941,6 +2214,10 @@ else:
"Client protocol %s succeeded with server protocol %s!"
% (ssl.get_protocol_name(client_protocol),
ssl.get_protocol_name(server_protocol)))
+ elif (expect_success is not True
+ and expect_success != stats['version']):
+ raise AssertionError("version mismatch: expected %r, got %r"
+ % (expect_success, stats['version']))
class ThreadedTests(unittest.TestCase):
@@ -2108,7 +2385,7 @@ else:
# and sets Event `listener_gone` to let the main thread know
# the socket is gone.
def listener():
- s.listen(5)
+ s.listen()
listener_ready.set()
newsock, addr = s.accept()
newsock.close()
@@ -2173,19 +2450,19 @@ else:
" SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
% str(x))
if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3')
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1')
if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
# Server with specific SSL options
if hasattr(ssl, 'PROTOCOL_SSLv3'):
@@ -2205,9 +2482,9 @@ else:
"""Connecting to an SSLv3 server with various client options"""
if support.verbose:
sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True)
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
if hasattr(ssl, 'PROTOCOL_SSLv2'):
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False,
@@ -2215,7 +2492,7 @@ else:
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
if no_sslv2_implies_sslv3_hello():
# No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, True,
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, 'SSLv3',
client_options=ssl.OP_NO_SSLv2)
@skip_if_broken_ubuntu_ssl
@@ -2223,9 +2500,9 @@ else:
"""Connecting to a TLSv1 server with various client options"""
if support.verbose:
sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
if hasattr(ssl, 'PROTOCOL_SSLv2'):
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
if hasattr(ssl, 'PROTOCOL_SSLv3'):
@@ -2241,7 +2518,7 @@ else:
Testing against older TLS versions."""
if support.verbose:
sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, True)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
if hasattr(ssl, 'PROTOCOL_SSLv2'):
try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
if hasattr(ssl, 'PROTOCOL_SSLv3'):
@@ -2249,7 +2526,7 @@ else:
try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False,
client_options=ssl.OP_NO_TLSv1_1)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False)
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False)
@@ -2262,7 +2539,7 @@ else:
Testing against older TLS versions."""
if support.verbose:
sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, True,
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
if hasattr(ssl, 'PROTOCOL_SSLv2'):
@@ -2272,7 +2549,7 @@ else:
try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False,
client_options=ssl.OP_NO_TLSv1_2)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
@@ -2508,6 +2785,36 @@ else:
s.write(b"over\n")
s.close()
+ def test_nonblocking_send(self):
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ cacerts=CERTFILE,
+ chatty=True,
+ connectionchatty=False)
+ with server:
+ s = ssl.wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ s.setblocking(False)
+
+ # If we keep sending data, at some point the buffers
+ # will be full and the call will block
+ buf = bytearray(8192)
+ def fill_buffer():
+ while True:
+ s.send(buf)
+ self.assertRaises((ssl.SSLWantWriteError,
+ ssl.SSLWantReadError), fill_buffer)
+
+ # Now read all the output and discard it
+ s.setblocking(True)
+ s.close()
+
def test_handshake_timeout(self):
# Issue #5103: SSL handshake must respect the socket timeout
server = socket.socket(socket.AF_INET)
@@ -2517,7 +2824,7 @@ else:
finish = False
def serve():
- server.listen(5)
+ server.listen()
started.set()
conns = []
while not finish:
@@ -2574,7 +2881,7 @@ else:
peer = None
def serve():
nonlocal remote, peer
- server.listen(5)
+ server.listen()
# Block on the accept and wait on the connection to close.
evt.set()
remote, peer = server.accept()
@@ -2624,6 +2931,21 @@ else:
s.connect((HOST, server.port))
self.assertIn("no shared cipher", str(server.conn_errors[0]))
+ def test_version_basic(self):
+ """
+ Basic tests for SSLSocket.version().
+ More tests are done in the test_protocol_*() methods.
+ """
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ with ThreadedEchoServer(CERTFILE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ chatty=False) as server:
+ with context.wrap_socket(socket.socket()) as s:
+ self.assertIs(s.version(), None)
+ s.connect((HOST, server.port))
+ self.assertEqual(s.version(), "TLSv1")
+ self.assertIs(s.version(), None)
+
@unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
def test_default_ecdh_curve(self):
# Issue #21015: elliptic curve-based Diffie Hellman key exchange
@@ -2733,6 +3055,55 @@ else:
if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
self.fail("Non-DH cipher: " + cipher[0])
+ def test_selected_alpn_protocol(self):
+ # selected_alpn_protocol() is None unless ALPN is used.
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_alpn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
+ def test_selected_alpn_protocol_if_server_uses_alpn(self):
+ # selected_alpn_protocol() is None unless ALPN is used by the client.
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.load_verify_locations(CERTFILE)
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(CERTFILE)
+ server_context.set_alpn_protocols(['foo', 'bar'])
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_alpn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
+ def test_alpn_protocols(self):
+ server_protocols = ['foo', 'bar', 'milkshake']
+ protocol_tests = [
+ (['foo', 'bar'], 'foo'),
+ (['bar', 'foo'], 'foo'),
+ (['milkshake'], 'milkshake'),
+ (['http/3.0', 'http/4.0'], None)
+ ]
+ for client_protocols, expected in protocol_tests:
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(CERTFILE)
+ server_context.set_alpn_protocols(server_protocols)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.load_cert_chain(CERTFILE)
+ client_context.set_alpn_protocols(client_protocols)
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True)
+
+ msg = "failed trying %s (s) and %s (c).\n" \
+ "was expecting %s, but got %%s from the %%s" \
+ % (str(server_protocols), str(client_protocols),
+ str(expected))
+ client_result = stats['client_alpn_protocol']
+ self.assertEqual(client_result, expected, msg % (client_result, "client"))
+ server_result = stats['server_alpn_protocols'][-1] \
+ if len(stats['server_alpn_protocols']) else 'nothing'
+ self.assertEqual(server_result, expected, msg % (server_result, "server"))
+
def test_selected_npn_protocol(self):
# selected_npn_protocol() is None unless NPN is used
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
@@ -2873,6 +3244,20 @@ else:
self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
self.assertIn("TypeError", stderr.getvalue())
+ def test_shared_ciphers(self):
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.load_verify_locations(SIGNING_CA)
+ client_context.set_ciphers("RC4")
+ server_context.set_ciphers("AES:RC4")
+ stats = server_params_test(client_context, server_context)
+ ciphers = stats['server_shared_ciphers'][0]
+ self.assertGreater(len(ciphers), 0)
+ for name, tls_version, bits in ciphers:
+ self.assertIn("RC4", name.split("-"))
+
def test_read_write_after_close_raises_valuerror(self):
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
context.verify_mode = ssl.CERT_REQUIRED
@@ -2888,6 +3273,23 @@ else:
self.assertRaises(ValueError, s.read, 1024)
self.assertRaises(ValueError, s.write, b'hello')
+ def test_sendfile(self):
+ TEST_DATA = b"x" * 512
+ with open(support.TESTFN, 'wb') as f:
+ f.write(TEST_DATA)
+ self.addCleanup(support.unlink, support.TESTFN)
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+ server = ThreadedEchoServer(context=context, chatty=False)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
+ s.connect((HOST, server.port))
+ with open(support.TESTFN, 'rb') as file:
+ s.sendfile(file)
+ self.assertEqual(s.recv(1024), TEST_DATA)
+
def test_main(verbose=False):
if support.verbose:
@@ -2921,10 +3323,11 @@ def test_main(verbose=False):
if not os.path.exists(filename):
raise support.TestFailed("Can't read certificate file %r" % filename)
- tests = [ContextTests, BasicSocketTests, SSLErrorTests]
+ tests = [ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests]
if support.is_resource_enabled('network'):
tests.append(NetworkedTests)
+ tests.append(NetworkedBIOTests)
if _have_threads:
thread_info = support.threading_setup()