summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r--Lib/test/test_ssl.py642
1 files changed, 549 insertions, 93 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 19ef354..f48103e 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -60,7 +60,7 @@ REMOTE_ROOT_CERT = data_file("selfsigned_pythontestdotnet.pem")
EMPTYCERT = data_file("nullcert.pem")
BADCERT = data_file("badcert.pem")
-WRONGCERT = data_file("XXXnonexisting.pem")
+NONEXISTINGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem")
NOKIACERT = data_file("nokia.pem")
NULLBYTECERT = data_file("nullbytecert.pem")
@@ -86,6 +86,12 @@ def have_verify_flags():
# 0.9.8 or higher
return ssl.OPENSSL_VERSION_INFO >= (0, 9, 8, 0, 15)
+def utc_offset(): #NOTE: ignore issues like #1647654
+ # local time = utc time + utc offset
+ if time.daylight and time.localtime().tm_isdst > 0:
+ return -time.altzone # seconds
+ return -time.timezone
+
def asn1time(cert_time):
# Some versions of OpenSSL ignore seconds, see #18207
# 0.9.8.i
@@ -134,6 +140,14 @@ class BasicSocketTests(unittest.TestCase):
self.assertIn(ssl.HAS_SNI, {True, False})
self.assertIn(ssl.HAS_ECDH, {True, False})
+ def test_str_for_enums(self):
+ # Make sure that the PROTOCOL_* constants have enum-like string
+ # reprs.
+ proto = ssl.PROTOCOL_SSLv23
+ self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_SSLv23')
+ ctx = ssl.SSLContext(proto)
+ self.assertIs(ctx.protocol, proto)
+
def test_random(self):
v = ssl.RAND_status()
if support.verbose:
@@ -158,6 +172,8 @@ class BasicSocketTests(unittest.TestCase):
self.assertRaises(TypeError, ssl.RAND_egd, 1)
self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1)
ssl.RAND_add("this is a random string", 75.0)
+ ssl.RAND_add(b"this is a random bytes object", 75.0)
+ ssl.RAND_add(bytearray(b"this is a random bytearray object"), 75.0)
@unittest.skipUnless(os.name == 'posix', 'requires posix')
def test_random_fork(self):
@@ -298,10 +314,10 @@ class BasicSocketTests(unittest.TestCase):
# Version string as returned by {Open,Libre}SSL, the format might change
if "LibreSSL" in s:
self.assertTrue(s.startswith("LibreSSL {:d}.{:d}".format(major, minor)),
- (s, t))
+ (s, t, hex(n)))
else:
self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
- (s, t))
+ (s, t, hex(n)))
@support.cpython_only
def test_refcycle(self):
@@ -351,17 +367,42 @@ class BasicSocketTests(unittest.TestCase):
s.connect, (HOST, 8080))
with self.assertRaises(OSError) as cm:
with socket.socket() as sock:
- ssl.wrap_socket(sock, certfile=WRONGCERT)
+ ssl.wrap_socket(sock, certfile=NONEXISTINGCERT)
self.assertEqual(cm.exception.errno, errno.ENOENT)
with self.assertRaises(OSError) as cm:
with socket.socket() as sock:
- ssl.wrap_socket(sock, certfile=CERTFILE, keyfile=WRONGCERT)
+ ssl.wrap_socket(sock,
+ certfile=CERTFILE, keyfile=NONEXISTINGCERT)
self.assertEqual(cm.exception.errno, errno.ENOENT)
with self.assertRaises(OSError) as cm:
with socket.socket() as sock:
- ssl.wrap_socket(sock, certfile=WRONGCERT, keyfile=WRONGCERT)
+ ssl.wrap_socket(sock,
+ certfile=NONEXISTINGCERT, keyfile=NONEXISTINGCERT)
self.assertEqual(cm.exception.errno, errno.ENOENT)
+ def bad_cert_test(self, certfile):
+ """Check that trying to use the given client certificate fails"""
+ certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
+ certfile)
+ sock = socket.socket()
+ self.addCleanup(sock.close)
+ with self.assertRaises(ssl.SSLError):
+ ssl.wrap_socket(sock,
+ certfile=certfile,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+
+ def test_empty_cert(self):
+ """Wrapping with an empty cert file"""
+ self.bad_cert_test("nullcert.pem")
+
+ def test_malformed_cert(self):
+ """Wrapping with a badly formatted certificate (syntax error)"""
+ self.bad_cert_test("badcert.pem")
+
+ def test_malformed_key(self):
+ """Wrapping with a badly formatted key (syntax error)"""
+ self.bad_cert_test("badkey.pem")
+
def test_match_hostname(self):
def ok(cert, hostname):
ssl.match_hostname(cert, hostname)
@@ -369,6 +410,8 @@ class BasicSocketTests(unittest.TestCase):
self.assertRaises(ssl.CertificateError,
ssl.match_hostname, cert, hostname)
+ # -- Hostname matching --
+
cert = {'subject': ((('commonName', 'example.com'),),)}
ok(cert, 'example.com')
ok(cert, 'ExAmple.cOm')
@@ -454,6 +497,28 @@ class BasicSocketTests(unittest.TestCase):
# Only commonName is considered
fail(cert, 'California')
+ # -- IPv4 matching --
+ cert = {'subject': ((('commonName', 'example.com'),),),
+ 'subjectAltName': (('DNS', 'example.com'),
+ ('IP Address', '10.11.12.13'),
+ ('IP Address', '14.15.16.17'))}
+ ok(cert, '10.11.12.13')
+ ok(cert, '14.15.16.17')
+ fail(cert, '14.15.16.18')
+ fail(cert, 'example.net')
+
+ # -- IPv6 matching --
+ cert = {'subject': ((('commonName', 'example.com'),),),
+ 'subjectAltName': (('DNS', 'example.com'),
+ ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'),
+ ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))}
+ ok(cert, '2001::cafe')
+ ok(cert, '2003::baba')
+ fail(cert, '2003::bebe')
+ fail(cert, 'example.net')
+
+ # -- Miscellaneous --
+
# Neither commonName nor subjectAltName
cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
'subject': ((('countryName', 'US'),),
@@ -505,9 +570,14 @@ class BasicSocketTests(unittest.TestCase):
def test_unknown_channel_binding(self):
# should raise ValueError for unknown type
s = socket.socket(socket.AF_INET)
- with ssl.wrap_socket(s) as ss:
+ s.bind(('127.0.0.1', 0))
+ s.listen()
+ c = socket.socket(socket.AF_INET)
+ c.connect(s.getsockname())
+ with ssl.wrap_socket(c, do_handshake_on_connect=False) as ss:
with self.assertRaises(ValueError):
ss.get_channel_binding("unknown-type")
+ s.close()
@unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
"'tls-unique' channel binding not available")
@@ -648,6 +718,71 @@ class BasicSocketTests(unittest.TestCase):
ctx.wrap_socket(s)
self.assertEqual(str(cx.exception), "only stream sockets are supported")
+ def cert_time_ok(self, timestring, timestamp):
+ self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp)
+
+ def cert_time_fail(self, timestring):
+ with self.assertRaises(ValueError):
+ ssl.cert_time_to_seconds(timestring)
+
+ @unittest.skipUnless(utc_offset(),
+ 'local time needs to be different from UTC')
+ def test_cert_time_to_seconds_timezone(self):
+ # Issue #19940: ssl.cert_time_to_seconds() returns wrong
+ # results if local timezone is not UTC
+ self.cert_time_ok("May 9 00:00:00 2007 GMT", 1178668800.0)
+ self.cert_time_ok("Jan 5 09:34:43 2018 GMT", 1515144883.0)
+
+ def test_cert_time_to_seconds(self):
+ timestring = "Jan 5 09:34:43 2018 GMT"
+ ts = 1515144883.0
+ self.cert_time_ok(timestring, ts)
+ # accept keyword parameter, assert its name
+ self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts)
+ # accept both %e and %d (space or zero generated by strftime)
+ self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts)
+ # case-insensitive
+ self.cert_time_ok("JaN 5 09:34:43 2018 GmT", ts)
+ self.cert_time_fail("Jan 5 09:34 2018 GMT") # no seconds
+ self.cert_time_fail("Jan 5 09:34:43 2018") # no GMT
+ self.cert_time_fail("Jan 5 09:34:43 2018 UTC") # not GMT timezone
+ self.cert_time_fail("Jan 35 09:34:43 2018 GMT") # invalid day
+ self.cert_time_fail("Jon 5 09:34:43 2018 GMT") # invalid month
+ self.cert_time_fail("Jan 5 24:00:00 2018 GMT") # invalid hour
+ self.cert_time_fail("Jan 5 09:60:43 2018 GMT") # invalid minute
+
+ newyear_ts = 1230768000.0
+ # leap seconds
+ self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts)
+ # same timestamp
+ self.cert_time_ok("Jan 1 00:00:00 2009 GMT", newyear_ts)
+
+ self.cert_time_ok("Jan 5 09:34:59 2018 GMT", 1515144899)
+ # allow 60th second (even if it is not a leap second)
+ self.cert_time_ok("Jan 5 09:34:60 2018 GMT", 1515144900)
+ # allow 2nd leap second for compatibility with time.strptime()
+ self.cert_time_ok("Jan 5 09:34:61 2018 GMT", 1515144901)
+ self.cert_time_fail("Jan 5 09:34:62 2018 GMT") # invalid seconds
+
+ # no special treatement for the special value:
+ # 99991231235959Z (rfc 5280)
+ self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
+
+ @support.run_with_locale('LC_ALL', '')
+ def test_cert_time_to_seconds_locale(self):
+ # `cert_time_to_seconds()` should be locale independent
+
+ def local_february_name():
+ return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0))
+
+ if local_february_name().lower() == 'feb':
+ self.skipTest("locale-specific month name needs to be "
+ "different from C locale")
+
+ # locale-independent
+ self.cert_time_ok("Feb 9 00:00:00 2007 GMT", 1170979200.0)
+ self.cert_time_fail(local_february_name() + " 9 00:00:00 2007 GMT")
+
class ContextTests(unittest.TestCase):
@@ -734,7 +869,7 @@ class ContextTests(unittest.TestCase):
ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE)
self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE)
with self.assertRaises(OSError) as cm:
- ctx.load_cert_chain(WRONGCERT)
+ ctx.load_cert_chain(NONEXISTINGCERT)
self.assertEqual(cm.exception.errno, errno.ENOENT)
with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
ctx.load_cert_chain(BADCERT)
@@ -819,7 +954,7 @@ class ContextTests(unittest.TestCase):
self.assertRaises(TypeError, ctx.load_verify_locations)
self.assertRaises(TypeError, ctx.load_verify_locations, None, None, None)
with self.assertRaises(OSError) as cm:
- ctx.load_verify_locations(WRONGCERT)
+ ctx.load_verify_locations(NONEXISTINGCERT)
self.assertEqual(cm.exception.errno, errno.ENOENT)
with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
ctx.load_verify_locations(BADCERT)
@@ -895,7 +1030,7 @@ class ContextTests(unittest.TestCase):
self.assertRaises(TypeError, ctx.load_dh_params)
self.assertRaises(TypeError, ctx.load_dh_params, None)
with self.assertRaises(FileNotFoundError) as cm:
- ctx.load_dh_params(WRONGCERT)
+ ctx.load_dh_params(NONEXISTINGCERT)
self.assertEqual(cm.exception.errno, errno.ENOENT)
with self.assertRaises(ssl.SSLError) as cm:
ctx.load_dh_params(CERTFILE)
@@ -1158,7 +1293,7 @@ class SSLErrorTests(unittest.TestCase):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
- s.listen(5)
+ s.listen()
c = socket.socket()
c.connect(s.getsockname())
c.setblocking(False)
@@ -1171,6 +1306,69 @@ class SSLErrorTests(unittest.TestCase):
self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)
+class MemoryBIOTests(unittest.TestCase):
+
+ def test_read_write(self):
+ bio = ssl.MemoryBIO()
+ bio.write(b'foo')
+ self.assertEqual(bio.read(), b'foo')
+ self.assertEqual(bio.read(), b'')
+ bio.write(b'foo')
+ bio.write(b'bar')
+ self.assertEqual(bio.read(), b'foobar')
+ self.assertEqual(bio.read(), b'')
+ bio.write(b'baz')
+ self.assertEqual(bio.read(2), b'ba')
+ self.assertEqual(bio.read(1), b'z')
+ self.assertEqual(bio.read(1), b'')
+
+ def test_eof(self):
+ bio = ssl.MemoryBIO()
+ self.assertFalse(bio.eof)
+ self.assertEqual(bio.read(), b'')
+ self.assertFalse(bio.eof)
+ bio.write(b'foo')
+ self.assertFalse(bio.eof)
+ bio.write_eof()
+ self.assertFalse(bio.eof)
+ self.assertEqual(bio.read(2), b'fo')
+ self.assertFalse(bio.eof)
+ self.assertEqual(bio.read(1), b'o')
+ self.assertTrue(bio.eof)
+ self.assertEqual(bio.read(), b'')
+ self.assertTrue(bio.eof)
+
+ def test_pending(self):
+ bio = ssl.MemoryBIO()
+ self.assertEqual(bio.pending, 0)
+ bio.write(b'foo')
+ self.assertEqual(bio.pending, 3)
+ for i in range(3):
+ bio.read(1)
+ self.assertEqual(bio.pending, 3-i-1)
+ for i in range(3):
+ bio.write(b'x')
+ self.assertEqual(bio.pending, i+1)
+ bio.read()
+ self.assertEqual(bio.pending, 0)
+
+ def test_buffer_types(self):
+ bio = ssl.MemoryBIO()
+ bio.write(b'foo')
+ self.assertEqual(bio.read(), b'foo')
+ bio.write(bytearray(b'bar'))
+ self.assertEqual(bio.read(), b'bar')
+ bio.write(memoryview(b'baz'))
+ self.assertEqual(bio.read(), b'baz')
+
+ def test_error_types(self):
+ bio = ssl.MemoryBIO()
+ self.assertRaises(TypeError, bio.write, 'foo')
+ self.assertRaises(TypeError, bio.write, None)
+ self.assertRaises(TypeError, bio.write, True)
+ self.assertRaises(TypeError, bio.write, 1)
+
+
class NetworkedTests(unittest.TestCase):
def test_connect(self):
@@ -1402,14 +1600,12 @@ class NetworkedTests(unittest.TestCase):
def test_get_server_certificate(self):
def _test_get_server_certificate(host, port, cert=None):
with support.transient_internet(host):
- pem = ssl.get_server_certificate((host, port),
- ssl.PROTOCOL_SSLv23)
+ pem = ssl.get_server_certificate((host, port))
if not pem:
self.fail("No server certificate on %s:%s!" % (host, port))
try:
pem = ssl.get_server_certificate((host, port),
- ssl.PROTOCOL_SSLv23,
ca_certs=CERTFILE)
except ssl.SSLError as x:
#should fail
@@ -1419,7 +1615,6 @@ class NetworkedTests(unittest.TestCase):
self.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
pem = ssl.get_server_certificate((host, port),
- ssl.PROTOCOL_SSLv23,
ca_certs=cert)
if not pem:
self.fail("No server certificate on %s:%s!" % (host, port))
@@ -1505,6 +1700,94 @@ class NetworkedTests(unittest.TestCase):
self.assertIs(ss.context, ctx2)
self.assertIs(ss._sslobj.context, ctx2)
+
+class NetworkedBIOTests(unittest.TestCase):
+
+ def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs):
+ # A simple IO loop. Call func(*args) depending on the error we get
+ # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs.
+ timeout = kwargs.get('timeout', 10)
+ count = 0
+ while True:
+ errno = None
+ count += 1
+ try:
+ ret = func(*args)
+ except ssl.SSLError as e:
+ if e.errno not in (ssl.SSL_ERROR_WANT_READ,
+ ssl.SSL_ERROR_WANT_WRITE):
+ raise
+ errno = e.errno
+ # Get any data from the outgoing BIO irrespective of any error, and
+ # send it to the socket.
+ buf = outgoing.read()
+ sock.sendall(buf)
+ # If there's no error, we're done. For WANT_READ, we need to get
+ # data from the socket and put it in the incoming BIO.
+ if errno is None:
+ break
+ elif errno == ssl.SSL_ERROR_WANT_READ:
+ buf = sock.recv(32768)
+ if buf:
+ incoming.write(buf)
+ else:
+ incoming.write_eof()
+ if support.verbose:
+ sys.stdout.write("Needed %d calls to complete %s().\n"
+ % (count, func.__name__))
+ return ret
+
+ def test_handshake(self):
+ with support.transient_internet(REMOTE_HOST):
+ sock = socket.socket(socket.AF_INET)
+ sock.connect((REMOTE_HOST, 443))
+ incoming = ssl.MemoryBIO()
+ outgoing = ssl.MemoryBIO()
+ ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ ctx.verify_mode = ssl.CERT_REQUIRED
+ ctx.load_verify_locations(REMOTE_ROOT_CERT)
+ ctx.check_hostname = True
+ sslobj = ctx.wrap_bio(incoming, outgoing, False, REMOTE_HOST)
+ self.assertIs(sslobj._sslobj.owner, sslobj)
+ self.assertIsNone(sslobj.cipher())
+ self.assertIsNone(sslobj.shared_ciphers())
+ self.assertRaises(ValueError, sslobj.getpeercert)
+ if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
+ self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
+ self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
+ self.assertTrue(sslobj.cipher())
+ self.assertIsNone(sslobj.shared_ciphers())
+ self.assertTrue(sslobj.getpeercert())
+ if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
+ self.assertTrue(sslobj.get_channel_binding('tls-unique'))
+ try:
+ self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
+ except ssl.SSLSyscallError:
+ # self-signed.pythontest.net probably shuts down the TCP
+ # connection without sending a secure shutdown message, and
+ # this is reported as SSL_ERROR_SYSCALL
+ pass
+ self.assertRaises(ssl.SSLError, sslobj.write, b'foo')
+ sock.close()
+
+ def test_read_write_data(self):
+ with support.transient_internet(REMOTE_HOST):
+ sock = socket.socket(socket.AF_INET)
+ sock.connect((REMOTE_HOST, 443))
+ incoming = ssl.MemoryBIO()
+ outgoing = ssl.MemoryBIO()
+ ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ ctx.verify_mode = ssl.CERT_NONE
+ sslobj = ctx.wrap_bio(incoming, outgoing, False)
+ self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
+ req = b'GET / HTTP/1.0\r\n\r\n'
+ self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
+ buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
+ self.assertEqual(buf[:5], b'HTTP/')
+ self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
+ sock.close()
+
+
try:
import threading
except ImportError:
@@ -1536,7 +1819,8 @@ else:
try:
self.sslconn = self.server.context.wrap_socket(
self.sock, server_side=True)
- self.server.selected_protocols.append(self.sslconn.selected_npn_protocol())
+ self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
+ self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
except (ssl.SSLError, ConnectionResetError) as e:
# We treat ConnectionResetError as though it were an
# SSLError - OpenSSL on Ubuntu abruptly closes the
@@ -1553,6 +1837,7 @@ else:
self.close()
return False
else:
+ self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
if self.server.context.verify_mode == ssl.CERT_REQUIRED:
cert = self.sslconn.getpeercert()
if support.verbose and self.server.chatty:
@@ -1643,7 +1928,8 @@ else:
def __init__(self, certificate=None, ssl_version=None,
certreqs=None, cacerts=None,
chatty=True, connectionchatty=False, starttls_server=False,
- npn_protocols=None, ciphers=None, context=None):
+ npn_protocols=None, alpn_protocols=None,
+ ciphers=None, context=None):
if context:
self.context = context
else:
@@ -1658,6 +1944,8 @@ else:
self.context.load_cert_chain(certificate)
if npn_protocols:
self.context.set_npn_protocols(npn_protocols)
+ if alpn_protocols:
+ self.context.set_alpn_protocols(alpn_protocols)
if ciphers:
self.context.set_ciphers(ciphers)
self.chatty = chatty
@@ -1667,7 +1955,9 @@ else:
self.port = support.bind_port(self.sock)
self.flag = None
self.active = False
- self.selected_protocols = []
+ self.selected_npn_protocols = []
+ self.selected_alpn_protocols = []
+ self.shared_ciphers = []
self.conn_errors = []
threading.Thread.__init__(self)
self.daemon = True
@@ -1687,7 +1977,7 @@ else:
def run(self):
self.sock.settimeout(0.05)
- self.sock.listen(5)
+ self.sock.listen()
self.active = True
if self.flag:
# signal an event
@@ -1826,36 +2116,6 @@ else:
self.active = False
self.server.close()
- def bad_cert_test(certfile):
- """
- Launch a server with CERT_REQUIRED, and check that trying to
- connect to it with the given client certificate fails.
- """
- server = ThreadedEchoServer(CERTFILE,
- certreqs=ssl.CERT_REQUIRED,
- cacerts=CERTFILE, chatty=False,
- connectionchatty=False)
- with server:
- try:
- with socket.socket() as sock:
- s = ssl.wrap_socket(sock,
- certfile=certfile,
- ssl_version=ssl.PROTOCOL_TLSv1)
- s.connect((HOST, server.port))
- except ssl.SSLError as x:
- if support.verbose:
- sys.stdout.write("\nSSLError is %s\n" % x.args[1])
- except OSError as x:
- if support.verbose:
- sys.stdout.write("\nOSError is %s\n" % x.args[1])
- except OSError as x:
- if x.errno != errno.ENOENT:
- raise
- if support.verbose:
- sys.stdout.write("\OSError is %s\n" % str(x))
- else:
- raise AssertionError("Use of invalid cert should have failed!")
-
def server_params_test(client_context, server_context, indata=b"FOO\n",
chatty=True, connectionchatty=False, sni_name=None):
"""
@@ -1893,14 +2153,25 @@ else:
'compression': s.compression(),
'cipher': s.cipher(),
'peercert': s.getpeercert(),
- 'client_npn_protocol': s.selected_npn_protocol()
+ 'client_alpn_protocol': s.selected_alpn_protocol(),
+ 'client_npn_protocol': s.selected_npn_protocol(),
+ 'version': s.version(),
})
s.close()
- stats['server_npn_protocols'] = server.selected_protocols
+ stats['server_alpn_protocols'] = server.selected_alpn_protocols
+ stats['server_npn_protocols'] = server.selected_npn_protocols
+ stats['server_shared_ciphers'] = server.shared_ciphers
return stats
def try_protocol_combo(server_protocol, client_protocol, expect_success,
certsreqs=None, server_options=0, client_options=0):
+ """
+ Try to SSL-connect using *client_protocol* to *server_protocol*.
+ If *expect_success* is true, assert that the connection succeeds,
+ if it's false, assert that the connection fails.
+ Also, if *expect_success* is a string, assert that it is the protocol
+ version actually used by the connection.
+ """
if certsreqs is None:
certsreqs = ssl.CERT_NONE
certtype = {
@@ -1930,8 +2201,8 @@ else:
ctx.load_cert_chain(CERTFILE)
ctx.load_verify_locations(CERTFILE)
try:
- server_params_test(client_context, server_context,
- chatty=False, connectionchatty=False)
+ stats = server_params_test(client_context, server_context,
+ chatty=False, connectionchatty=False)
# Protocol mismatch can result in either an SSLError, or a
# "Connection reset by peer" error.
except ssl.SSLError:
@@ -1946,6 +2217,10 @@ else:
"Client protocol %s succeeded with server protocol %s!"
% (ssl.get_protocol_name(client_protocol),
ssl.get_protocol_name(server_protocol)))
+ elif (expect_success is not True
+ and expect_success != stats['version']):
+ raise AssertionError("version mismatch: expected %r, got %r"
+ % (expect_success, stats['version']))
class ThreadedTests(unittest.TestCase):
@@ -2081,22 +2356,38 @@ else:
"check_hostname requires server_hostname"):
context.wrap_socket(s)
- def test_empty_cert(self):
- """Connecting with an empty cert file"""
- bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
- "nullcert.pem"))
- def test_malformed_cert(self):
- """Connecting with a badly formatted certificate (syntax error)"""
- bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
- "badcert.pem"))
- def test_nonexisting_cert(self):
- """Connecting with a non-existing cert file"""
- bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
- "wrongcert.pem"))
- def test_malformed_key(self):
- """Connecting with a badly formatted key (syntax error)"""
- bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
- "badkey.pem"))
+ def test_wrong_cert(self):
+ """Connecting when the server rejects the client's certificate
+
+ Launch a server with CERT_REQUIRED, and check that trying to
+ connect to it with a wrong client certificate fails.
+ """
+ certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
+ "wrongcert.pem")
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_REQUIRED,
+ cacerts=CERTFILE, chatty=False,
+ connectionchatty=False)
+ with server, \
+ socket.socket() as sock, \
+ ssl.wrap_socket(sock,
+ certfile=certfile,
+ ssl_version=ssl.PROTOCOL_TLSv1) as s:
+ try:
+ # Expect either an SSL error about the server rejecting
+ # the connection, or a low-level connection reset (which
+ # sometimes happens on Windows)
+ s.connect((HOST, server.port))
+ except ssl.SSLError as e:
+ if support.verbose:
+ sys.stdout.write("\nSSLError is %r\n" % e)
+ except OSError as e:
+ if e.errno != errno.ECONNRESET:
+ raise
+ if support.verbose:
+ sys.stdout.write("\nsocket.error is %r\n" % e)
+ else:
+ self.fail("Use of invalid cert should have failed!")
def test_rude_shutdown(self):
"""A brutal shutdown of an SSL server should raise an OSError
@@ -2113,7 +2404,7 @@ else:
# and sets Event `listener_gone` to let the main thread know
# the socket is gone.
def listener():
- s.listen(5)
+ s.listen()
listener_ready.set()
newsock, addr = s.accept()
newsock.close()
@@ -2180,17 +2471,17 @@ else:
if hasattr(ssl, 'PROTOCOL_SSLv3'):
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1')
if hasattr(ssl, 'PROTOCOL_SSLv3'):
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
if hasattr(ssl, 'PROTOCOL_SSLv3'):
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
# Server with specific SSL options
if hasattr(ssl, 'PROTOCOL_SSLv3'):
@@ -2210,9 +2501,9 @@ else:
"""Connecting to an SSLv3 server with various client options"""
if support.verbose:
sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True)
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
if hasattr(ssl, 'PROTOCOL_SSLv2'):
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False,
@@ -2228,9 +2519,9 @@ else:
"""Connecting to a TLSv1 server with various client options"""
if support.verbose:
sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
if hasattr(ssl, 'PROTOCOL_SSLv2'):
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
if hasattr(ssl, 'PROTOCOL_SSLv3'):
@@ -2246,7 +2537,7 @@ else:
Testing against older TLS versions."""
if support.verbose:
sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, True)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
if hasattr(ssl, 'PROTOCOL_SSLv2'):
try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
if hasattr(ssl, 'PROTOCOL_SSLv3'):
@@ -2254,7 +2545,7 @@ else:
try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False,
client_options=ssl.OP_NO_TLSv1_1)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False)
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False)
@@ -2267,7 +2558,7 @@ else:
Testing against older TLS versions."""
if support.verbose:
sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, True,
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
if hasattr(ssl, 'PROTOCOL_SSLv2'):
@@ -2277,7 +2568,7 @@ else:
try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False,
client_options=ssl.OP_NO_TLSv1_2)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
@@ -2502,6 +2793,13 @@ else:
# consume data
s.read()
+ # read(-1, buffer) is supported, even though read(-1) is not
+ data = b"data"
+ s.send(data)
+ buffer = bytearray(len(data))
+ self.assertEqual(s.read(-1, buffer), len(data))
+ self.assertEqual(buffer, data)
+
# Make sure sendmsg et al are disallowed to avoid
# inadvertent disclosure of data and/or corruption
# of the encrypted data stream
@@ -2511,6 +2809,60 @@ else:
s.recvmsg_into, bytearray(100))
s.write(b"over\n")
+
+ self.assertRaises(ValueError, s.recv, -1)
+ self.assertRaises(ValueError, s.read, -1)
+
+ s.close()
+
+ def test_recv_zero(self):
+ server = ThreadedEchoServer(CERTFILE)
+ server.__enter__()
+ self.addCleanup(server.__exit__, None, None)
+ s = socket.create_connection((HOST, server.port))
+ self.addCleanup(s.close)
+ s = ssl.wrap_socket(s, suppress_ragged_eofs=False)
+ self.addCleanup(s.close)
+
+ # recv/read(0) should return no data
+ s.send(b"data")
+ self.assertEqual(s.recv(0), b"")
+ self.assertEqual(s.read(0), b"")
+ self.assertEqual(s.read(), b"data")
+
+ # Should not block if the other end sends no data
+ s.setblocking(False)
+ self.assertEqual(s.recv(0), b"")
+ self.assertEqual(s.recv_into(bytearray()), 0)
+
+ def test_nonblocking_send(self):
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ cacerts=CERTFILE,
+ chatty=True,
+ connectionchatty=False)
+ with server:
+ s = ssl.wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ s.setblocking(False)
+
+ # If we keep sending data, at some point the buffers
+ # will be full and the call will block
+ buf = bytearray(8192)
+ def fill_buffer():
+ while True:
+ s.send(buf)
+ self.assertRaises((ssl.SSLWantWriteError,
+ ssl.SSLWantReadError), fill_buffer)
+
+ # Now read all the output and discard it
+ s.setblocking(True)
s.close()
def test_handshake_timeout(self):
@@ -2522,7 +2874,7 @@ else:
finish = False
def serve():
- server.listen(5)
+ server.listen()
started.set()
conns = []
while not finish:
@@ -2579,7 +2931,7 @@ else:
peer = None
def serve():
nonlocal remote, peer
- server.listen(5)
+ server.listen()
# Block on the accept and wait on the connection to close.
evt.set()
remote, peer = server.accept()
@@ -2629,6 +2981,21 @@ else:
s.connect((HOST, server.port))
self.assertIn("no shared cipher", str(server.conn_errors[0]))
+ def test_version_basic(self):
+ """
+ Basic tests for SSLSocket.version().
+ More tests are done in the test_protocol_*() methods.
+ """
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ with ThreadedEchoServer(CERTFILE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ chatty=False) as server:
+ with context.wrap_socket(socket.socket()) as s:
+ self.assertIs(s.version(), None)
+ s.connect((HOST, server.port))
+ self.assertEqual(s.version(), "TLSv1")
+ self.assertIs(s.version(), None)
+
@unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
def test_default_ecdh_curve(self):
# Issue #21015: elliptic curve-based Diffie Hellman key exchange
@@ -2738,6 +3105,55 @@ else:
if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
self.fail("Non-DH cipher: " + cipher[0])
+ def test_selected_alpn_protocol(self):
+ # selected_alpn_protocol() is None unless ALPN is used.
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_alpn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
+ def test_selected_alpn_protocol_if_server_uses_alpn(self):
+ # selected_alpn_protocol() is None unless ALPN is used by the client.
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.load_verify_locations(CERTFILE)
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(CERTFILE)
+ server_context.set_alpn_protocols(['foo', 'bar'])
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_alpn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
+ def test_alpn_protocols(self):
+ server_protocols = ['foo', 'bar', 'milkshake']
+ protocol_tests = [
+ (['foo', 'bar'], 'foo'),
+ (['bar', 'foo'], 'foo'),
+ (['milkshake'], 'milkshake'),
+ (['http/3.0', 'http/4.0'], None)
+ ]
+ for client_protocols, expected in protocol_tests:
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(CERTFILE)
+ server_context.set_alpn_protocols(server_protocols)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.load_cert_chain(CERTFILE)
+ client_context.set_alpn_protocols(client_protocols)
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True)
+
+ msg = "failed trying %s (s) and %s (c).\n" \
+ "was expecting %s, but got %%s from the %%s" \
+ % (str(server_protocols), str(client_protocols),
+ str(expected))
+ client_result = stats['client_alpn_protocol']
+ self.assertEqual(client_result, expected, msg % (client_result, "client"))
+ server_result = stats['server_alpn_protocols'][-1] \
+ if len(stats['server_alpn_protocols']) else 'nothing'
+ self.assertEqual(server_result, expected, msg % (server_result, "server"))
+
def test_selected_npn_protocol(self):
# selected_npn_protocol() is None unless NPN is used
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
@@ -2878,6 +3294,20 @@ else:
self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
self.assertIn("TypeError", stderr.getvalue())
+ def test_shared_ciphers(self):
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.load_verify_locations(SIGNING_CA)
+ client_context.set_ciphers("RC4")
+ server_context.set_ciphers("AES:RC4")
+ stats = server_params_test(client_context, server_context)
+ ciphers = stats['server_shared_ciphers'][0]
+ self.assertGreater(len(ciphers), 0)
+ for name, tls_version, bits in ciphers:
+ self.assertIn("RC4", name.split("-"))
+
def test_read_write_after_close_raises_valuerror(self):
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
context.verify_mode = ssl.CERT_REQUIRED
@@ -2893,21 +3323,46 @@ else:
self.assertRaises(ValueError, s.read, 1024)
self.assertRaises(ValueError, s.write, b'hello')
+ def test_sendfile(self):
+ TEST_DATA = b"x" * 512
+ with open(support.TESTFN, 'wb') as f:
+ f.write(TEST_DATA)
+ self.addCleanup(support.unlink, support.TESTFN)
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+ server = ThreadedEchoServer(context=context, chatty=False)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
+ s.connect((HOST, server.port))
+ with open(support.TESTFN, 'rb') as file:
+ s.sendfile(file)
+ self.assertEqual(s.recv(1024), TEST_DATA)
+
def test_main(verbose=False):
if support.verbose:
+ import warnings
plats = {
'Linux': platform.linux_distribution,
'Mac': platform.mac_ver,
'Windows': platform.win32_ver,
}
- for name, func in plats.items():
- plat = func()
- if plat and plat[0]:
- plat = '%s %r' % (name, plat)
- break
- else:
- plat = repr(platform.platform())
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ 'ignore',
+ 'dist\(\) and linux_distribution\(\) '
+ 'functions are deprecated .*',
+ PendingDeprecationWarning,
+ )
+ for name, func in plats.items():
+ plat = func()
+ if plat and plat[0]:
+ plat = '%s %r' % (name, plat)
+ break
+ else:
+ plat = repr(platform.platform())
print("test_ssl: testing with %r %r" %
(ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO))
print(" under %s" % plat)
@@ -2926,10 +3381,11 @@ def test_main(verbose=False):
if not os.path.exists(filename):
raise support.TestFailed("Can't read certificate file %r" % filename)
- tests = [ContextTests, BasicSocketTests, SSLErrorTests]
+ tests = [ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests]
if support.is_resource_enabled('network'):
tests.append(NetworkedTests)
+ tests.append(NetworkedBIOTests)
if _have_threads:
thread_info = support.threading_setup()