summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r--Lib/test/test_ssl.py612
1 files changed, 523 insertions, 89 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 5c8df8b..e52a71c 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -42,6 +42,9 @@ ONLYCERT = data_file("ssl_cert.pem")
ONLYKEY = data_file("ssl_key.pem")
BYTES_ONLYCERT = os.fsencode(ONLYCERT)
BYTES_ONLYKEY = os.fsencode(ONLYKEY)
+CERTFILE_PROTECTED = data_file("keycert.passwd.pem")
+ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem")
+KEY_PASSWORD = "somepass"
CAPATH = data_file("capath")
BYTES_CAPATH = os.fsencode(CAPATH)
@@ -52,7 +55,10 @@ BADCERT = data_file("badcert.pem")
WRONGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem")
NOKIACERT = data_file("nokia.pem")
+NULLBYTECERT = data_file("nullbytecert.pem")
+DHFILE = data_file("dh512.pem")
+BYTES_DHFILE = os.fsencode(DHFILE)
def handle_error(prefix):
exc_format = ' '.join(traceback.format_exception(*sys.exc_info()))
@@ -95,7 +101,14 @@ class BasicSocketTests(unittest.TestCase):
ssl.CERT_NONE
ssl.CERT_OPTIONAL
ssl.CERT_REQUIRED
+ ssl.OP_CIPHER_SERVER_PREFERENCE
+ ssl.OP_SINGLE_DH_USE
+ if ssl.HAS_ECDH:
+ ssl.OP_SINGLE_ECDH_USE
+ if ssl.OPENSSL_VERSION_INFO >= (1, 0):
+ ssl.OP_NO_COMPRESSION
self.assertIn(ssl.HAS_SNI, {True, False})
+ self.assertIn(ssl.HAS_ECDH, {True, False})
def test_random(self):
v = ssl.RAND_status()
@@ -103,10 +116,56 @@ class BasicSocketTests(unittest.TestCase):
sys.stdout.write("\n RAND_status is %d (%s)\n"
% (v, (v and "sufficient randomness") or
"insufficient randomness"))
+
+ data, is_cryptographic = ssl.RAND_pseudo_bytes(16)
+ self.assertEqual(len(data), 16)
+ self.assertEqual(is_cryptographic, v == 1)
+ if v:
+ data = ssl.RAND_bytes(16)
+ self.assertEqual(len(data), 16)
+ else:
+ self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16)
+
+ # negative num is invalid
+ self.assertRaises(ValueError, ssl.RAND_bytes, -5)
+ self.assertRaises(ValueError, ssl.RAND_pseudo_bytes, -5)
+
self.assertRaises(TypeError, ssl.RAND_egd, 1)
self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1)
ssl.RAND_add("this is a random string", 75.0)
+ @unittest.skipUnless(os.name == 'posix', 'requires posix')
+ def test_random_fork(self):
+ status = ssl.RAND_status()
+ if not status:
+ self.fail("OpenSSL's PRNG has insufficient randomness")
+
+ rfd, wfd = os.pipe()
+ pid = os.fork()
+ if pid == 0:
+ try:
+ os.close(rfd)
+ child_random = ssl.RAND_pseudo_bytes(16)[0]
+ self.assertEqual(len(child_random), 16)
+ os.write(wfd, child_random)
+ os.close(wfd)
+ except BaseException:
+ os._exit(1)
+ else:
+ os._exit(0)
+ else:
+ os.close(wfd)
+ self.addCleanup(os.close, rfd)
+ _, status = os.waitpid(pid, 0)
+ self.assertEqual(status, 0)
+
+ child_random = os.read(rfd, 16)
+ self.assertEqual(len(child_random), 16)
+ parent_random = ssl.RAND_pseudo_bytes(16)[0]
+ self.assertEqual(len(parent_random), 16)
+
+ self.assertNotEqual(child_random, parent_random)
+
def test_parse_cert(self):
# note that this uses an 'unofficial' function in _ssl.c,
# provided solely for this test, to exercise the certificate
@@ -140,6 +199,35 @@ class BasicSocketTests(unittest.TestCase):
('DNS', 'projects.forum.nokia.com'))
)
+ def test_parse_cert_CVE_2013_4238(self):
+ p = ssl._ssl._test_decode_cert(NULLBYTECERT)
+ if support.verbose:
+ sys.stdout.write("\n" + pprint.pformat(p) + "\n")
+ subject = ((('countryName', 'US'),),
+ (('stateOrProvinceName', 'Oregon'),),
+ (('localityName', 'Beaverton'),),
+ (('organizationName', 'Python Software Foundation'),),
+ (('organizationalUnitName', 'Python Core Development'),),
+ (('commonName', 'null.python.org\x00example.org'),),
+ (('emailAddress', 'python-dev@python.org'),))
+ self.assertEqual(p['subject'], subject)
+ self.assertEqual(p['issuer'], subject)
+ if ssl._OPENSSL_API_VERSION >= (0, 9, 8):
+ san = (('DNS', 'altnull.python.org\x00example.com'),
+ ('email', 'null@python.org\x00user@example.org'),
+ ('URI', 'http://null.python.org\x00http://example.org'),
+ ('IP Address', '192.0.2.1'),
+ ('IP Address', '2001:DB8:0:0:0:0:0:1\n'))
+ else:
+ # OpenSSL 0.9.7 doesn't support IPv6 addresses in subjectAltName
+ san = (('DNS', 'altnull.python.org\x00example.com'),
+ ('email', 'null@python.org\x00user@example.org'),
+ ('URI', 'http://null.python.org\x00http://example.org'),
+ ('IP Address', '192.0.2.1'),
+ ('IP Address', '<invalid>'))
+
+ self.assertEqual(p['subjectAltName'], san)
+
def test_DER_to_PEM(self):
with open(SVN_PYTHON_ORG_ROOT_CERT, 'r') as f:
pem = f.read()
@@ -186,20 +274,21 @@ class BasicSocketTests(unittest.TestCase):
s = socket.socket(socket.AF_INET)
ss = ssl.wrap_socket(s)
wr = weakref.ref(ss)
- del ss
- self.assertEqual(wr(), None)
+ with support.check_warnings(("", ResourceWarning)):
+ del ss
+ self.assertEqual(wr(), None)
def test_wrapped_unconnected(self):
# Methods on an unconnected SSLSocket propagate the original
# socket.error raise by the underlying socket object.
s = socket.socket(socket.AF_INET)
- ss = ssl.wrap_socket(s)
- self.assertRaises(socket.error, ss.recv, 1)
- self.assertRaises(socket.error, ss.recv_into, bytearray(b'x'))
- self.assertRaises(socket.error, ss.recvfrom, 1)
- self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1)
- self.assertRaises(socket.error, ss.send, b'x')
- self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0))
+ with ssl.wrap_socket(s) as ss:
+ self.assertRaises(socket.error, ss.recv, 1)
+ self.assertRaises(socket.error, ss.recv_into, bytearray(b'x'))
+ self.assertRaises(socket.error, ss.recvfrom, 1)
+ self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1)
+ self.assertRaises(socket.error, ss.send, b'x')
+ self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0))
def test_timeout(self):
# Issue #8524: when creating an SSL socket, the timeout of the
@@ -207,8 +296,8 @@ class BasicSocketTests(unittest.TestCase):
for timeout in (None, 0.0, 5.0):
s = socket.socket(socket.AF_INET)
s.settimeout(timeout)
- ss = ssl.wrap_socket(s)
- self.assertEqual(timeout, ss.gettimeout())
+ with ssl.wrap_socket(s) as ss:
+ self.assertEqual(timeout, ss.gettimeout())
def test_errors(self):
sock = socket.socket()
@@ -221,9 +310,9 @@ class BasicSocketTests(unittest.TestCase):
self.assertRaisesRegex(ValueError,
"certfile must be specified for server-side operations",
ssl.wrap_socket, sock, server_side=True, certfile="")
- s = ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE)
- self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
- s.connect, (HOST, 8080))
+ with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s:
+ self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
+ s.connect, (HOST, 8080))
with self.assertRaises(IOError) as cm:
with socket.socket() as sock:
ssl.wrap_socket(sock, certfile=WRONGCERT)
@@ -259,11 +348,7 @@ class BasicSocketTests(unittest.TestCase):
fail(cert, 'Xa.com')
fail(cert, '.a.com')
- cert = {'subject': ((('commonName', 'a.*.com'),),)}
- ok(cert, 'a.foo.com')
- fail(cert, 'a..com')
- fail(cert, 'a.com')
-
+ # only match one left-most wildcard
cert = {'subject': ((('commonName', 'f*.com'),),)}
ok(cert, 'foo.com')
ok(cert, 'f.com')
@@ -271,6 +356,43 @@ class BasicSocketTests(unittest.TestCase):
fail(cert, 'foo.a.com')
fail(cert, 'bar.foo.com')
+ # NULL bytes are bad, CVE-2013-4073
+ cert = {'subject': ((('commonName',
+ 'null.python.org\x00example.org'),),)}
+ ok(cert, 'null.python.org\x00example.org') # or raise an error?
+ fail(cert, 'example.org')
+ fail(cert, 'null.python.org')
+
+ # error cases with wildcards
+ cert = {'subject': ((('commonName', '*.*.a.com'),),)}
+ fail(cert, 'bar.foo.a.com')
+ fail(cert, 'a.com')
+ fail(cert, 'Xa.com')
+ fail(cert, '.a.com')
+
+ cert = {'subject': ((('commonName', 'a.*.com'),),)}
+ fail(cert, 'a.foo.com')
+ fail(cert, 'a..com')
+ fail(cert, 'a.com')
+
+ # wildcard doesn't match IDNA prefix 'xn--'
+ idna = 'püthon.python.org'.encode("idna").decode("ascii")
+ cert = {'subject': ((('commonName', idna),),)}
+ ok(cert, idna)
+ cert = {'subject': ((('commonName', 'x*.python.org'),),)}
+ fail(cert, idna)
+ cert = {'subject': ((('commonName', 'xn--p*.python.org'),),)}
+ fail(cert, idna)
+
+ # wildcard in first fragment and IDNA A-labels in sequent fragments
+ # are supported.
+ idna = 'www*.pythön.org'.encode("idna").decode("ascii")
+ cert = {'subject': ((('commonName', idna),),)}
+ ok(cert, 'www.pythön.org'.encode("idna").decode("ascii"))
+ ok(cert, 'www1.pythön.org'.encode("idna").decode("ascii"))
+ fail(cert, 'ftp.pythön.org'.encode("idna").decode("ascii"))
+ fail(cert, 'pythön.org'.encode("idna").decode("ascii"))
+
# Slightly fake real-world example
cert = {'notAfter': 'Jun 26 21:41:46 2011 GMT',
'subject': ((('commonName', 'linuxfrz.org'),),),
@@ -331,7 +453,7 @@ class BasicSocketTests(unittest.TestCase):
cert = {'subject': ((('commonName', 'a*b.com'),),)}
ok(cert, 'axxb.com')
cert = {'subject': ((('commonName', 'a*b.co*'),),)}
- ok(cert, 'axxb.com')
+ fail(cert, 'axxb.com')
cert = {'subject': ((('commonName', 'a*b*.com'),),)}
with self.assertRaises(ssl.CertificateError) as cm:
ssl.match_hostname(cert, 'axxbxxc.com')
@@ -344,6 +466,45 @@ class BasicSocketTests(unittest.TestCase):
self.assertRaises(ValueError, ctx.wrap_socket, sock, True,
server_hostname="some.hostname")
+ def test_unknown_channel_binding(self):
+ # should raise ValueError for unknown type
+ s = socket.socket(socket.AF_INET)
+ with ssl.wrap_socket(s) as ss:
+ with self.assertRaises(ValueError):
+ ss.get_channel_binding("unknown-type")
+
+ @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
+ "'tls-unique' channel binding not available")
+ def test_tls_unique_channel_binding(self):
+ # unconnected should return None for known type
+ s = socket.socket(socket.AF_INET)
+ with ssl.wrap_socket(s) as ss:
+ self.assertIsNone(ss.get_channel_binding("tls-unique"))
+ # the same for server-side
+ s = socket.socket(socket.AF_INET)
+ with ssl.wrap_socket(s, server_side=True, certfile=CERTFILE) as ss:
+ self.assertIsNone(ss.get_channel_binding("tls-unique"))
+
+ def test_dealloc_warn(self):
+ ss = ssl.wrap_socket(socket.socket(socket.AF_INET))
+ r = repr(ss)
+ with self.assertWarns(ResourceWarning) as cm:
+ ss = None
+ support.gc_collect()
+ self.assertIn(r, str(cm.warning.args[0]))
+
+ def test_unsupported_dtls(self):
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ self.addCleanup(s.close)
+ with self.assertRaises(NotImplementedError) as cx:
+ ssl.wrap_socket(s, cert_reqs=ssl.CERT_NONE)
+ self.assertEqual(str(cx.exception), "only stream sockets are supported")
+ ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ with self.assertRaises(NotImplementedError) as cx:
+ ctx.wrap_socket(s)
+ self.assertEqual(str(cx.exception), "only stream sockets are supported")
+
+
class ContextTests(unittest.TestCase):
@skip_if_broken_ubuntu_ssl
@@ -373,9 +534,7 @@ class ContextTests(unittest.TestCase):
@skip_if_broken_ubuntu_ssl
def test_options(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- # OP_ALL is the default value
- self.assertEqual(ssl.OP_ALL, ctx.options)
- ctx.options |= ssl.OP_NO_SSLv2
+ # OP_ALL | OP_NO_SSLv2 is the default value
self.assertEqual(ssl.OP_ALL | ssl.OP_NO_SSLv2,
ctx.options)
ctx.options |= ssl.OP_NO_SSLv3
@@ -434,6 +593,60 @@ class ContextTests(unittest.TestCase):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):
ctx.load_cert_chain(SVN_PYTHON_ORG_ROOT_CERT, ONLYKEY)
+ # Password protected key and cert
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD)
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode())
+ ctx.load_cert_chain(CERTFILE_PROTECTED,
+ password=bytearray(KEY_PASSWORD.encode()))
+ ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD)
+ ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode())
+ ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED,
+ bytearray(KEY_PASSWORD.encode()))
+ with self.assertRaisesRegex(TypeError, "should be a string"):
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=True)
+ with self.assertRaises(ssl.SSLError):
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass")
+ with self.assertRaisesRegex(ValueError, "cannot be longer"):
+ # openssl has a fixed limit on the password buffer.
+ # PEM_BUFSIZE is generally set to 1kb.
+ # Return a string larger than this.
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400)
+ # Password callback
+ def getpass_unicode():
+ return KEY_PASSWORD
+ def getpass_bytes():
+ return KEY_PASSWORD.encode()
+ def getpass_bytearray():
+ return bytearray(KEY_PASSWORD.encode())
+ def getpass_badpass():
+ return "badpass"
+ def getpass_huge():
+ return b'a' * (1024 * 1024)
+ def getpass_bad_type():
+ return 9
+ def getpass_exception():
+ raise Exception('getpass error')
+ class GetPassCallable:
+ def __call__(self):
+ return KEY_PASSWORD
+ def getpass(self):
+ return KEY_PASSWORD
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode)
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes)
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray)
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable())
+ ctx.load_cert_chain(CERTFILE_PROTECTED,
+ password=GetPassCallable().getpass)
+ with self.assertRaises(ssl.SSLError):
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass)
+ with self.assertRaisesRegex(ValueError, "cannot be longer"):
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge)
+ with self.assertRaisesRegex(TypeError, "must return a string"):
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type)
+ with self.assertRaisesRegex(Exception, "getpass error"):
+ ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception)
+ # Make sure the password function isn't called if it isn't needed
+ ctx.load_cert_chain(CERTFILE, password=getpass_exception)
def test_load_verify_locations(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
@@ -454,6 +667,19 @@ class ContextTests(unittest.TestCase):
# Issue #10989: crash if the second argument type is invalid
self.assertRaises(TypeError, ctx.load_verify_locations, None, True)
+ def test_load_dh_params(self):
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx.load_dh_params(DHFILE)
+ if os.name != 'nt':
+ ctx.load_dh_params(BYTES_DHFILE)
+ self.assertRaises(TypeError, ctx.load_dh_params)
+ self.assertRaises(TypeError, ctx.load_dh_params, None)
+ with self.assertRaises(FileNotFoundError) as cm:
+ ctx.load_dh_params(WRONGCERT)
+ self.assertEqual(cm.exception.errno, errno.ENOENT)
+ with self.assertRaises(ssl.SSLError) as cm:
+ ctx.load_dh_params(CERTFILE)
+
@skip_if_broken_ubuntu_ssl
def test_session_stats(self):
for proto in PROTOCOLS:
@@ -478,6 +704,57 @@ class ContextTests(unittest.TestCase):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
ctx.set_default_verify_paths()
+ @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build")
+ def test_set_ecdh_curve(self):
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx.set_ecdh_curve("prime256v1")
+ ctx.set_ecdh_curve(b"prime256v1")
+ self.assertRaises(TypeError, ctx.set_ecdh_curve)
+ self.assertRaises(TypeError, ctx.set_ecdh_curve, None)
+ self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
+ self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")
+
+
+class SSLErrorTests(unittest.TestCase):
+
+ def test_str(self):
+ # The str() of a SSLError doesn't include the errno
+ e = ssl.SSLError(1, "foo")
+ self.assertEqual(str(e), "foo")
+ self.assertEqual(e.errno, 1)
+ # Same for a subclass
+ e = ssl.SSLZeroReturnError(1, "foo")
+ self.assertEqual(str(e), "foo")
+ self.assertEqual(e.errno, 1)
+
+ def test_lib_reason(self):
+ # Test the library and reason attributes
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ with self.assertRaises(ssl.SSLError) as cm:
+ ctx.load_dh_params(CERTFILE)
+ self.assertEqual(cm.exception.library, 'PEM')
+ self.assertEqual(cm.exception.reason, 'NO_START_LINE')
+ s = str(cm.exception)
+ self.assertTrue(s.startswith("[PEM: NO_START_LINE] no start line"), s)
+
+ def test_subclass(self):
+ # Check that the appropriate SSLError subclass is raised
+ # (this only tests one of them)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ with socket.socket() as s:
+ s.bind(("127.0.0.1", 0))
+ s.listen(5)
+ c = socket.socket()
+ c.connect(s.getsockname())
+ c.setblocking(False)
+ with ctx.wrap_socket(c, False, do_handshake_on_connect=False) as c:
+ with self.assertRaises(ssl.SSLWantReadError) as cm:
+ c.do_handshake()
+ s = str(cm.exception)
+ self.assertTrue(s.startswith("The operation did not complete (read)"), s)
+ # For compatibility
+ self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)
+
class NetworkedTests(unittest.TestCase):
@@ -540,13 +817,10 @@ class NetworkedTests(unittest.TestCase):
try:
s.do_handshake()
break
- except ssl.SSLError as err:
- if err.args[0] == ssl.SSL_ERROR_WANT_READ:
- select.select([s], [], [], 5.0)
- elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
- select.select([], [s], [], 5.0)
- else:
- raise
+ except ssl.SSLWantReadError:
+ select.select([s], [], [], 5.0)
+ except ssl.SSLWantWriteError:
+ select.select([], [s], [], 5.0)
# SSL established
self.assertTrue(s.getpeercert())
finally:
@@ -575,8 +849,10 @@ class NetworkedTests(unittest.TestCase):
cert_reqs=ssl.CERT_REQUIRED,
ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
try:
- self.assertEqual(errno.ECONNREFUSED,
- s.connect_ex(("svn.python.org", 444)))
+ rc = s.connect_ex(("svn.python.org", 444))
+ # Issue #19919: Windows machines or VMs hosted on Windows
+ # machines sometimes return EWOULDBLOCK.
+ self.assertIn(rc, (errno.ECONNREFUSED, errno.EWOULDBLOCK))
finally:
s.close()
@@ -677,52 +953,54 @@ class NetworkedTests(unittest.TestCase):
count += 1
s.do_handshake()
break
- except ssl.SSLError as err:
- if err.args[0] == ssl.SSL_ERROR_WANT_READ:
- select.select([s], [], [])
- elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
- select.select([], [s], [])
- else:
- raise
+ except ssl.SSLWantReadError:
+ select.select([s], [], [])
+ except ssl.SSLWantWriteError:
+ select.select([], [s], [])
s.close()
if support.verbose:
sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
def test_get_server_certificate(self):
- with support.transient_internet("svn.python.org"):
- pem = ssl.get_server_certificate(("svn.python.org", 443),
- ssl.PROTOCOL_SSLv23)
- if not pem:
- self.fail("No server certificate on svn.python.org:443!")
+ def _test_get_server_certificate(host, port, cert=None):
+ with support.transient_internet(host):
+ pem = ssl.get_server_certificate((host, port),
+ ssl.PROTOCOL_SSLv23)
+ if not pem:
+ self.fail("No server certificate on %s:%s!" % (host, port))
- try:
- pem = ssl.get_server_certificate(("svn.python.org", 443),
+ try:
+ pem = ssl.get_server_certificate((host, port),
+ ssl.PROTOCOL_SSLv23,
+ ca_certs=CERTFILE)
+ except ssl.SSLError as x:
+ #should fail
+ if support.verbose:
+ sys.stdout.write("%s\n" % x)
+ else:
+ self.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
+
+ pem = ssl.get_server_certificate((host, port),
ssl.PROTOCOL_SSLv23,
- ca_certs=CERTFILE)
- except ssl.SSLError as x:
- #should fail
+ ca_certs=cert)
+ if not pem:
+ self.fail("No server certificate on %s:%s!" % (host, port))
if support.verbose:
- sys.stdout.write("%s\n" % x)
- else:
- self.fail("Got server certificate %s for svn.python.org!" % pem)
+ sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port ,pem))
- pem = ssl.get_server_certificate(("svn.python.org", 443),
- ssl.PROTOCOL_SSLv23,
- ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
- if not pem:
- self.fail("No server certificate on svn.python.org:443!")
- if support.verbose:
- sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
+ _test_get_server_certificate('svn.python.org', 443, SVN_PYTHON_ORG_ROOT_CERT)
+ if support.IPV6_ENABLED:
+ _test_get_server_certificate('ipv6.google.com', 443)
def test_ciphers(self):
remote = ("svn.python.org", 443)
with support.transient_internet(remote[0]):
- s = ssl.wrap_socket(socket.socket(socket.AF_INET),
- cert_reqs=ssl.CERT_NONE, ciphers="ALL")
- s.connect(remote)
- s = ssl.wrap_socket(socket.socket(socket.AF_INET),
- cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT")
- s.connect(remote)
+ with ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_NONE, ciphers="ALL") as s:
+ s.connect(remote)
+ with ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") as s:
+ s.connect(remote)
# Error checking can happen at instantiation or when connecting
with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
with socket.socket(socket.AF_INET) as sock:
@@ -790,13 +1068,12 @@ else:
try:
self.sslconn = self.server.context.wrap_socket(
self.sock, server_side=True)
- except (ssl.SSLError, socket.error) as e:
- # Treat ECONNRESET as though it were an SSLError - OpenSSL
- # on Ubuntu abruptly closes the connection when asked to use
- # an unsupported protocol.
- if (not isinstance(e, ssl.SSLError) and
- e.errno != errno.ECONNRESET):
- raise
+ self.server.selected_protocols.append(self.sslconn.selected_npn_protocol())
+ except (ssl.SSLError, ConnectionResetError) as e:
+ # We treat ConnectionResetError as though it were an
+ # SSLError - OpenSSL on Ubuntu abruptly closes the
+ # connection when asked to use an unsupported protocol.
+ #
# XXX Various errors can have happened here, for example
# a mismatching protocol version, an invalid certificate,
# or a low-level bug. This should be made more discriminating.
@@ -818,6 +1095,8 @@ else:
cipher = self.sslconn.cipher()
if support.verbose and self.server.chatty:
sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
+ sys.stdout.write(" server: selected protocol is now "
+ + str(self.sslconn.selected_npn_protocol()) + "\n")
return True
def read(self):
@@ -872,6 +1151,11 @@ else:
self.sslconn = None
if support.verbose and self.server.connectionchatty:
sys.stdout.write(" server: connection is now unencrypted...\n")
+ elif stripped == b'CB tls-unique':
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
+ data = self.sslconn.get_channel_binding("tls-unique")
+ self.write(repr(data).encode("us-ascii") + b"\n")
else:
if (support.verbose and
self.server.connectionchatty):
@@ -891,7 +1175,7 @@ else:
def __init__(self, certificate=None, ssl_version=None,
certreqs=None, cacerts=None,
chatty=True, connectionchatty=False, starttls_server=False,
- ciphers=None, context=None):
+ npn_protocols=None, ciphers=None, context=None):
if context:
self.context = context
else:
@@ -904,6 +1188,8 @@ else:
self.context.load_verify_locations(cacerts)
if certificate:
self.context.load_cert_chain(certificate)
+ if npn_protocols:
+ self.context.set_npn_protocols(npn_protocols)
if ciphers:
self.context.set_ciphers(ciphers)
self.chatty = chatty
@@ -913,6 +1199,7 @@ else:
self.port = support.bind_port(self.sock)
self.flag = None
self.active = False
+ self.selected_protocols = []
self.conn_errors = []
threading.Thread.__init__(self)
self.daemon = True
@@ -980,12 +1267,11 @@ else:
def _do_ssl_handshake(self):
try:
self.socket.do_handshake()
- except ssl.SSLError as err:
- if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
- ssl.SSL_ERROR_WANT_WRITE):
- return
- elif err.args[0] == ssl.SSL_ERROR_EOF:
- return self.handle_close()
+ except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
+ return
+ except ssl.SSLEOFError:
+ return self.handle_close()
+ except ssl.SSLError:
raise
except socket.error as err:
if err.args[0] == errno.ECONNABORTED:
@@ -1108,6 +1394,7 @@ else:
Launch a server, connect a client to it and try various reads
and writes.
"""
+ stats = {}
server = ThreadedEchoServer(context=server_context,
chatty=chatty,
connectionchatty=False)
@@ -1133,7 +1420,14 @@ else:
if connectionchatty:
if support.verbose:
sys.stdout.write(" client: closing connection.\n")
+ stats.update({
+ 'compression': s.compression(),
+ 'cipher': s.cipher(),
+ 'client_npn_protocol': s.selected_npn_protocol()
+ })
s.close()
+ stats['server_npn_protocols'] = server.selected_protocols
+ return stats
def try_protocol_combo(server_protocol, client_protocol, expect_success,
certsreqs=None, server_options=0, client_options=0):
@@ -1151,9 +1445,9 @@ else:
ssl.get_protocol_name(server_protocol),
certtype))
client_context = ssl.SSLContext(client_protocol)
- client_context.options = ssl.OP_ALL | client_options
+ client_context.options |= client_options
server_context = ssl.SSLContext(server_protocol)
- server_context.options = ssl.OP_ALL | server_options
+ server_context.options |= server_options
for ctx in (client_context, server_context):
ctx.verify_mode = certsreqs
# NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client
@@ -1285,7 +1579,8 @@ else:
t.join()
@skip_if_broken_ubuntu_ssl
- @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'), "need SSLv2")
+ @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'),
+ "OpenSSL is compiled without SSLv2 support")
def test_protocol_sslv2(self):
"""Connecting to an SSLv2 server with various client options"""
if support.verbose:
@@ -1293,7 +1588,7 @@ else:
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
# SSLv23 client with specific SSL options
@@ -1301,9 +1596,9 @@ else:
# No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
client_options=ssl.OP_NO_SSLv2)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True,
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
client_options=ssl.OP_NO_SSLv3)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True,
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
client_options=ssl.OP_NO_TLSv1)
@skip_if_broken_ubuntu_ssl
@@ -1591,6 +1886,15 @@ else:
)
# consume data
s.read()
+
+ # Make sure sendmsg et al are disallowed to avoid
+ # inadvertent disclosure of data and/or corruption
+ # of the encrypted data stream
+ self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
+ self.assertRaises(NotImplementedError, s.recvmsg, 100)
+ self.assertRaises(NotImplementedError,
+ s.recvmsg_into, bytearray(100))
+
s.write(b"over\n")
s.close()
@@ -1675,6 +1979,8 @@ else:
client_addr = client.getsockname()
client.close()
t.join()
+ remote.close()
+ server.close()
# Sanity checks.
self.assertIsInstance(remote, ssl.SSLSocket)
self.assertEqual(peer, client_addr)
@@ -1689,12 +1995,140 @@ else:
with ThreadedEchoServer(CERTFILE,
ssl_version=ssl.PROTOCOL_SSLv23,
chatty=False) as server:
- with socket.socket() as sock:
- s = context.wrap_socket(sock)
+ with context.wrap_socket(socket.socket()) as s:
with self.assertRaises((OSError, ssl.SSLError)):
s.connect((HOST, server.port))
self.assertIn("no shared cipher", str(server.conn_errors[0]))
+ @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
+ "'tls-unique' channel binding not available")
+ def test_tls_unique_channel_binding(self):
+ """Test tls-unique channel binding."""
+ if support.verbose:
+ sys.stdout.write("\n")
+
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ cacerts=CERTFILE,
+ chatty=True,
+ connectionchatty=False)
+ with server:
+ s = ssl.wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ # get the data
+ cb_data = s.get_channel_binding("tls-unique")
+ if support.verbose:
+ sys.stdout.write(" got channel binding data: {0!r}\n"
+ .format(cb_data))
+
+ # check if it is sane
+ self.assertIsNotNone(cb_data)
+ self.assertEqual(len(cb_data), 12) # True for TLSv1
+
+ # and compare with the peers version
+ s.write(b"CB tls-unique\n")
+ peer_data_repr = s.read().strip()
+ self.assertEqual(peer_data_repr,
+ repr(cb_data).encode("us-ascii"))
+ s.close()
+
+ # now, again
+ s = ssl.wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ new_cb_data = s.get_channel_binding("tls-unique")
+ if support.verbose:
+ sys.stdout.write(" got another channel binding data: {0!r}\n"
+ .format(new_cb_data))
+ # is it really unique
+ self.assertNotEqual(cb_data, new_cb_data)
+ self.assertIsNotNone(cb_data)
+ self.assertEqual(len(cb_data), 12) # True for TLSv1
+ s.write(b"CB tls-unique\n")
+ peer_data_repr = s.read().strip()
+ self.assertEqual(peer_data_repr,
+ repr(new_cb_data).encode("us-ascii"))
+ s.close()
+
+ def test_compression(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ if support.verbose:
+ sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
+ self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
+
+ @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
+ "ssl.OP_NO_COMPRESSION needed for this test")
+ def test_compression_disabled(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ context.options |= ssl.OP_NO_COMPRESSION
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['compression'], None)
+
+ def test_dh_params(self):
+ # Check we can get a connection with ephemeral Diffie-Hellman
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ context.load_dh_params(DHFILE)
+ context.set_ciphers("kEDH")
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ cipher = stats["cipher"][0]
+ parts = cipher.split("-")
+ if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
+ self.fail("Non-DH cipher: " + cipher[0])
+
+ def test_selected_npn_protocol(self):
+ # selected_npn_protocol() is None unless NPN is used
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_npn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
+ def test_npn_protocols(self):
+ server_protocols = ['http/1.1', 'spdy/2']
+ protocol_tests = [
+ (['http/1.1', 'spdy/2'], 'http/1.1'),
+ (['spdy/2', 'http/1.1'], 'http/1.1'),
+ (['spdy/2', 'test'], 'spdy/2'),
+ (['abc', 'def'], 'abc')
+ ]
+ for client_protocols, expected in protocol_tests:
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(CERTFILE)
+ server_context.set_npn_protocols(server_protocols)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.load_cert_chain(CERTFILE)
+ client_context.set_npn_protocols(client_protocols)
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True)
+
+ msg = "failed trying %s (s) and %s (c).\n" \
+ "was expecting %s, but got %%s from the %%s" \
+ % (str(server_protocols), str(client_protocols),
+ str(expected))
+ client_result = stats['client_npn_protocol']
+ self.assertEqual(client_result, expected, msg % (client_result, "client"))
+ server_result = stats['server_npn_protocols'][-1] \
+ if len(stats['server_npn_protocols']) else 'nothing'
+ self.assertEqual(server_result, expected, msg % (server_result, "server"))
+
def test_main(verbose=False):
if support.verbose:
@@ -1722,14 +2156,14 @@ def test_main(verbose=False):
if not os.path.exists(filename):
raise support.TestFailed("Can't read certificate file %r" % filename)
- tests = [ContextTests, BasicSocketTests]
+ tests = [ContextTests, BasicSocketTests, SSLErrorTests]
if support.is_resource_enabled('network'):
tests.append(NetworkedTests)
if _have_threads:
thread_info = support.threading_setup()
- if thread_info and support.is_resource_enabled('network'):
+ if thread_info:
tests.append(ThreadedTests)
try: