summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBill Janssen <janssen@parc.com>2008-06-28 22:19:33 (GMT)
committerBill Janssen <janssen@parc.com>2008-06-28 22:19:33 (GMT)
commit934b16d0c2c4dcaa15051e4e7d61543f9f64fa82 (patch)
tree53b3eb297a86932d5966e2a540959a574f8ef02d
parenta27474c345becd19e2d39a2265cbcd31667df3f6 (diff)
downloadcpython-934b16d0c2c4dcaa15051e4e7d61543f9f64fa82.zip
cpython-934b16d0c2c4dcaa15051e4e7d61543f9f64fa82.tar.gz
cpython-934b16d0c2c4dcaa15051e4e7d61543f9f64fa82.tar.bz2
various SSL fixes; issues 1251, 3162, 3212
-rw-r--r--Doc/library/ssl.rst34
-rw-r--r--Lib/ssl.py361
-rw-r--r--Lib/test/test_ssl.py246
-rw-r--r--Lib/test/wrongcert.pem32
-rw-r--r--Modules/_ssl.c203
5 files changed, 528 insertions, 348 deletions
diff --git a/Doc/library/ssl.rst b/Doc/library/ssl.rst
index fb41091..a41c6ea 100644
--- a/Doc/library/ssl.rst
+++ b/Doc/library/ssl.rst
@@ -54,7 +54,7 @@ Functions, Constants, and Exceptions
network connection. This error is a subtype of :exc:`socket.error`, which
in turn is a subtype of :exc:`IOError`.
-.. function:: wrap_socket (sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version={see docs}, ca_certs=None)
+.. function:: wrap_socket (sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version={see docs}, ca_certs=None, do_handshake_on_connect=True, suppress_ragged_eofs=True)
Takes an instance ``sock`` of :class:`socket.socket`, and returns an instance of :class:`ssl.SSLSocket`, a subtype
of :class:`socket.socket`, which wraps the underlying socket in an SSL context.
@@ -122,6 +122,18 @@ Functions, Constants, and Exceptions
In some older versions of OpenSSL (for instance, 0.9.7l on OS X 10.4),
an SSLv2 client could not connect to an SSLv23 server.
+ The parameter ``do_handshake_on_connect`` specifies whether to do the SSL
+ handshake automatically after doing a :meth:`socket.connect`, or whether the
+ application program will call it explicitly, by invoking the :meth:`SSLSocket.do_handshake`
+ method. Calling :meth:`SSLSocket.do_handshake` explicitly gives the program control over
+ the blocking behavior of the socket I/O involved in the handshake.
+
+ The parameter ``suppress_ragged_eofs`` specifies how the :meth:`SSLSocket.read`
+ method should signal unexpected EOF from the other end of the connection. If specified
+ as :const:`True` (the default), it returns a normal EOF in response to unexpected
+ EOF errors raised from the underlying socket; if :const:`False`, it will raise
+ the exceptions back to the caller.
+
.. function:: RAND_status()
Returns True if the SSL pseudo-random number generator has been
@@ -290,6 +302,25 @@ SSLSocket Objects
number of secret bits being used. If no connection has been
established, returns ``None``.
+.. method:: SSLSocket.do_handshake()
+
+ Perform a TLS/SSL handshake. If this is used with a non-blocking socket,
+ it may raise :exc:`SSLError` with an ``arg[0]`` of :const:`SSL_ERROR_WANT_READ`
+ or :const:`SSL_ERROR_WANT_WRITE`, in which case it must be called again until it
+ completes successfully. For example, to simulate the behavior of a blocking socket,
+ one might write::
+
+ while True:
+ try:
+ s.do_handshake()
+ break
+ except ssl.SSLError, err:
+ if err.args[0] == ssl.SSL_ERROR_WANT_READ:
+ select.select([s], [], [])
+ elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
+ select.select([], [s], [])
+ else:
+ raise
.. index:: single: certificates
@@ -367,6 +398,7 @@ certificate, you need to provide a "CA certs" file, filled with the certificate
chains for each issuer you are willing to trust. Again, this file just
contains these chains concatenated together. For validation, Python will
use the first chain it finds in the file which matches.
+
Some "standard" root certificates are available from various certification
authorities:
`CACert.org <http://www.cacert.org/index.php?id=3>`_,
diff --git a/Lib/ssl.py b/Lib/ssl.py
index 24502e4..e45e16b 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -74,7 +74,7 @@ from _ssl import \
SSL_ERROR_EOF, \
SSL_ERROR_INVALID_ERROR_CODE
-from socket import socket
+from socket import socket, _fileobject
from socket import getnameinfo as _getnameinfo
import base64 # for DER-to-PEM translation
@@ -86,8 +86,16 @@ class SSLSocket (socket):
def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
- ssl_version=PROTOCOL_SSLv23, ca_certs=None):
+ ssl_version=PROTOCOL_SSLv23, ca_certs=None,
+ do_handshake_on_connect=True,
+ suppress_ragged_eofs=True):
socket.__init__(self, _sock=sock._sock)
+ # the initializer for socket trashes the methods (tsk, tsk), so...
+ self.send = lambda x, flags=0: SSLSocket.send(self, x, flags)
+ self.recv = lambda x, flags=0: SSLSocket.recv(self, x, flags)
+ self.sendto = lambda data, addr, flags=0: SSLSocket.sendto(self, data, addr, flags)
+ self.recvfrom = lambda addr, buflen, flags: SSLSocket.recvfrom(self, addr, buflen, flags)
+
if certfile and not keyfile:
keyfile = certfile
# see if it's connected
@@ -101,18 +109,34 @@ class SSLSocket (socket):
self._sslobj = _ssl.sslwrap(self._sock, server_side,
keyfile, certfile,
cert_reqs, ssl_version, ca_certs)
+ if do_handshake_on_connect:
+ timeout = self.gettimeout()
+ try:
+ self.settimeout(None)
+ self.do_handshake()
+ finally:
+ self.settimeout(timeout)
self.keyfile = keyfile
self.certfile = certfile
self.cert_reqs = cert_reqs
self.ssl_version = ssl_version
self.ca_certs = ca_certs
+ self.do_handshake_on_connect = do_handshake_on_connect
+ self.suppress_ragged_eofs = suppress_ragged_eofs
+ self._makefile_refs = 0
def read(self, len=1024):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
- return self._sslobj.read(len)
+ try:
+ return self._sslobj.read(len)
+ except SSLError, x:
+ if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
+ return ''
+ else:
+ raise
def write(self, data):
@@ -143,16 +167,27 @@ class SSLSocket (socket):
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
self.__class__)
- return self._sslobj.write(data)
+ while True:
+ try:
+ v = self._sslobj.write(data)
+ except SSLError, x:
+ if x.args[0] == SSL_ERROR_WANT_READ:
+ return 0
+ elif x.args[0] == SSL_ERROR_WANT_WRITE:
+ return 0
+ else:
+ raise
+ else:
+ return v
else:
return socket.send(self, data, flags)
- def send_to (self, data, addr, flags=0):
+ def sendto (self, data, addr, flags=0):
if self._sslobj:
- raise ValueError("send_to not allowed on instances of %s" %
+ raise ValueError("sendto not allowed on instances of %s" %
self.__class__)
else:
- return socket.send_to(self, data, addr, flags)
+ return socket.sendto(self, data, addr, flags)
def sendall (self, data, flags=0):
if self._sslobj:
@@ -160,7 +195,12 @@ class SSLSocket (socket):
raise ValueError(
"non-zero flags not allowed in calls to sendall() on %s" %
self.__class__)
- return self._sslobj.write(data)
+ amount = len(data)
+ count = 0
+ while (count < amount):
+ v = self.send(data[count:])
+ count += v
+ return amount
else:
return socket.sendall(self, data, flags)
@@ -170,25 +210,51 @@ class SSLSocket (socket):
raise ValueError(
"non-zero flags not allowed in calls to sendall() on %s" %
self.__class__)
- return self._sslobj.read(data, buflen)
+ while True:
+ try:
+ return self.read(buflen)
+ except SSLError, x:
+ if x.args[0] == SSL_ERROR_WANT_READ:
+ continue
+ else:
+ raise x
else:
return socket.recv(self, buflen, flags)
- def recv_from (self, addr, buflen=1024, flags=0):
+ def recvfrom (self, addr, buflen=1024, flags=0):
if self._sslobj:
- raise ValueError("recv_from not allowed on instances of %s" %
+ raise ValueError("recvfrom not allowed on instances of %s" %
self.__class__)
else:
- return socket.recv_from(self, addr, buflen, flags)
+ return socket.recvfrom(self, addr, buflen, flags)
- def shutdown(self, how):
+ def pending (self):
+ if self._sslobj:
+ return self._sslobj.pending()
+ else:
+ return 0
+
+ def shutdown (self, how):
self._sslobj = None
socket.shutdown(self, how)
- def close(self):
+ def close (self):
self._sslobj = None
socket.close(self)
+ def close (self):
+ if self._makefile_refs < 1:
+ self._sslobj = None
+ socket.close(self)
+ else:
+ self._makefile_refs -= 1
+
+ def do_handshake (self):
+
+ """Perform a TLS/SSL handshake."""
+
+ self._sslobj.do_handshake()
+
def connect(self, addr):
"""Connects to remote ADDR, and then wraps the connection in
@@ -202,6 +268,8 @@ class SSLSocket (socket):
self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs)
+ if self.do_handshake_on_connect:
+ self.do_handshake()
def accept(self):
@@ -210,260 +278,39 @@ class SSLSocket (socket):
SSL channel, and the address of the remote client."""
newsock, addr = socket.accept(self)
- return (SSLSocket(newsock, True, self.keyfile, self.certfile,
- self.cert_reqs, self.ssl_version,
- self.ca_certs), addr)
-
+ return (SSLSocket(newsock,
+ keyfile=self.keyfile,
+ certfile=self.certfile,
+ server_side=True,
+ cert_reqs=self.cert_reqs,
+ ssl_version=self.ssl_version,
+ ca_certs=self.ca_certs,
+ do_handshake_on_connect=self.do_handshake_on_connect,
+ suppress_ragged_eofs=self.suppress_ragged_eofs),
+ addr)
def makefile(self, mode='r', bufsize=-1):
"""Ouch. Need to make and return a file-like object that
works with the SSL connection."""
- if self._sslobj:
- return SSLFileStream(self._sslobj, mode, bufsize)
- else:
- return socket.makefile(self, mode, bufsize)
-
-
-class SSLFileStream:
-
- """A class to simulate a file stream on top of a socket.
- Most of this is just lifted from the socket module, and
- adjusted to work with an SSL stream instead of a socket."""
-
-
- default_bufsize = 8192
- name = "<SSL stream>"
-
- __slots__ = ["mode", "bufsize", "softspace",
- # "closed" is a property, see below
- "_sslobj", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf",
- "_close", "_fileno"]
-
- def __init__(self, sslobj, mode='rb', bufsize=-1, close=False):
- self._sslobj = sslobj
- self.mode = mode # Not actually used in this version
- if bufsize < 0:
- bufsize = self.default_bufsize
- self.bufsize = bufsize
- self.softspace = False
- if bufsize == 0:
- self._rbufsize = 1
- elif bufsize == 1:
- self._rbufsize = self.default_bufsize
- else:
- self._rbufsize = bufsize
- self._wbufsize = bufsize
- self._rbuf = "" # A string
- self._wbuf = [] # A list of strings
- self._close = close
- self._fileno = -1
-
- def _getclosed(self):
- return self._sslobj is None
- closed = property(_getclosed, doc="True if the file is closed")
-
- def fileno(self):
- return self._fileno
-
- def close(self):
- try:
- if self._sslobj:
- self.flush()
- finally:
- if self._close and self._sslobj:
- self._sslobj.close()
- self._sslobj = None
-
- def __del__(self):
- try:
- self.close()
- except:
- # close() may fail if __init__ didn't complete
- pass
-
- def flush(self):
- if self._wbuf:
- buffer = "".join(self._wbuf)
- self._wbuf = []
- count = 0
- while (count < len(buffer)):
- written = self._sslobj.write(buffer)
- count += written
- buffer = buffer[written:]
-
- def write(self, data):
- data = str(data) # XXX Should really reject non-string non-buffers
- if not data:
- return
- self._wbuf.append(data)
- if (self._wbufsize == 0 or
- self._wbufsize == 1 and '\n' in data or
- self._get_wbuf_len() >= self._wbufsize):
- self.flush()
-
- def writelines(self, list):
- # XXX We could do better here for very long lists
- # XXX Should really reject non-string non-buffers
- self._wbuf.extend(filter(None, map(str, list)))
- if (self._wbufsize <= 1 or
- self._get_wbuf_len() >= self._wbufsize):
- self.flush()
-
- def _get_wbuf_len(self):
- buf_len = 0
- for x in self._wbuf:
- buf_len += len(x)
- return buf_len
-
- def read(self, size=-1):
- data = self._rbuf
- if size < 0:
- # Read until EOF
- buffers = []
- if data:
- buffers.append(data)
- self._rbuf = ""
- if self._rbufsize <= 1:
- recv_size = self.default_bufsize
- else:
- recv_size = self._rbufsize
- while True:
- data = self._sslobj.read(recv_size)
- if not data:
- break
- buffers.append(data)
- return "".join(buffers)
- else:
- # Read until size bytes or EOF seen, whichever comes first
- buf_len = len(data)
- if buf_len >= size:
- self._rbuf = data[size:]
- return data[:size]
- buffers = []
- if data:
- buffers.append(data)
- self._rbuf = ""
- while True:
- left = size - buf_len
- recv_size = max(self._rbufsize, left)
- data = self._sslobj.read(recv_size)
- if not data:
- break
- buffers.append(data)
- n = len(data)
- if n >= left:
- self._rbuf = data[left:]
- buffers[-1] = data[:left]
- break
- buf_len += n
- return "".join(buffers)
-
- def readline(self, size=-1):
- data = self._rbuf
- if size < 0:
- # Read until \n or EOF, whichever comes first
- if self._rbufsize <= 1:
- # Speed up unbuffered case
- assert data == ""
- buffers = []
- while data != "\n":
- data = self._sslobj.read(1)
- if not data:
- break
- buffers.append(data)
- return "".join(buffers)
- nl = data.find('\n')
- if nl >= 0:
- nl += 1
- self._rbuf = data[nl:]
- return data[:nl]
- buffers = []
- if data:
- buffers.append(data)
- self._rbuf = ""
- while True:
- data = self._sslobj.read(self._rbufsize)
- if not data:
- break
- buffers.append(data)
- nl = data.find('\n')
- if nl >= 0:
- nl += 1
- self._rbuf = data[nl:]
- buffers[-1] = data[:nl]
- break
- return "".join(buffers)
- else:
- # Read until size bytes or \n or EOF seen, whichever comes first
- nl = data.find('\n', 0, size)
- if nl >= 0:
- nl += 1
- self._rbuf = data[nl:]
- return data[:nl]
- buf_len = len(data)
- if buf_len >= size:
- self._rbuf = data[size:]
- return data[:size]
- buffers = []
- if data:
- buffers.append(data)
- self._rbuf = ""
- while True:
- data = self._sslobj.read(self._rbufsize)
- if not data:
- break
- buffers.append(data)
- left = size - buf_len
- nl = data.find('\n', 0, left)
- if nl >= 0:
- nl += 1
- self._rbuf = data[nl:]
- buffers[-1] = data[:nl]
- break
- n = len(data)
- if n >= left:
- self._rbuf = data[left:]
- buffers[-1] = data[:left]
- break
- buf_len += n
- return "".join(buffers)
-
- def readlines(self, sizehint=0):
- total = 0
- list = []
- while True:
- line = self.readline()
- if not line:
- break
- list.append(line)
- total += len(line)
- if sizehint and total >= sizehint:
- break
- return list
-
- # Iterator protocols
-
- def __iter__(self):
- return self
-
- def next(self):
- line = self.readline()
- if not line:
- raise StopIteration
- return line
-
+ self._makefile_refs += 1
+ return _fileobject(self, mode, bufsize)
def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
- ssl_version=PROTOCOL_SSLv23, ca_certs=None):
+ ssl_version=PROTOCOL_SSLv23, ca_certs=None,
+ do_handshake_on_connect=True,
+ suppress_ragged_eofs=True):
return SSLSocket(sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
- ssl_version=ssl_version, ca_certs=ca_certs)
+ ssl_version=ssl_version, ca_certs=ca_certs,
+ do_handshake_on_connect=do_handshake_on_connect,
+ suppress_ragged_eofs=suppress_ragged_eofs)
+
# some utility functions
@@ -549,5 +396,7 @@ def sslwrap_simple (sock, keyfile=None, certfile=None):
for compability with Python 2.5 and earlier. Will disappear in
Python 3.0."""
- return _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE,
- PROTOCOL_SSLv23, None)
+ ssl_sock = _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE,
+ PROTOCOL_SSLv23, None)
+ ssl_sock.do_handshake()
+ return ssl_sock
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index eb4d00c..d786154 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -3,7 +3,9 @@
import sys
import unittest
from test import test_support
+import asyncore
import socket
+import select
import errno
import subprocess
import time
@@ -97,8 +99,7 @@ class BasicTests(unittest.TestCase):
if (d1 != d2):
raise test_support.TestFailed("PEM-to-DER or DER-to-PEM translation failed")
-
-class NetworkTests(unittest.TestCase):
+class NetworkedTests(unittest.TestCase):
def testConnect(self):
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
@@ -130,6 +131,31 @@ class NetworkTests(unittest.TestCase):
finally:
s.close()
+
+ def testNonBlockingHandshake(self):
+ s = socket.socket(socket.AF_INET)
+ s.connect(("svn.python.org", 443))
+ s.setblocking(False)
+ s = ssl.wrap_socket(s,
+ cert_reqs=ssl.CERT_NONE,
+ do_handshake_on_connect=False)
+ count = 0
+ while True:
+ try:
+ count += 1
+ s.do_handshake()
+ break
+ except ssl.SSLError, err:
+ if err.args[0] == ssl.SSL_ERROR_WANT_READ:
+ select.select([s], [], [])
+ elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
+ select.select([], [s], [])
+ else:
+ raise
+ s.close()
+ if test_support.verbose:
+ sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
+
def testFetchServerCert(self):
pem = ssl.get_server_certificate(("svn.python.org", 443))
@@ -176,6 +202,18 @@ else:
threading.Thread.__init__(self)
self.setDaemon(True)
+ def show_conn_details(self):
+ if self.server.certreqs == ssl.CERT_REQUIRED:
+ cert = self.sslconn.getpeercert()
+ if test_support.verbose and self.server.chatty:
+ sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
+ cert_binary = self.sslconn.getpeercert(True)
+ if test_support.verbose and self.server.chatty:
+ sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
+ cipher = self.sslconn.cipher()
+ if test_support.verbose and self.server.chatty:
+ sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
+
def wrap_conn (self):
try:
self.sslconn = ssl.wrap_socket(self.sock, server_side=True,
@@ -187,6 +225,7 @@ else:
if self.server.chatty:
handle_error("\n server: bad connection attempt from " +
str(self.sock.getpeername()) + ":\n")
+ self.close()
if not self.server.expect_bad_connects:
# here, we want to stop the server, because this shouldn't
# happen in the context of our test case
@@ -197,16 +236,6 @@ else:
return False
else:
- if self.server.certreqs == ssl.CERT_REQUIRED:
- cert = self.sslconn.getpeercert()
- if test_support.verbose and self.server.chatty:
- sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
- cert_binary = self.sslconn.getpeercert(True)
- if test_support.verbose and self.server.chatty:
- sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
- cipher = self.sslconn.cipher()
- if test_support.verbose and self.server.chatty:
- sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
return True
def read(self):
@@ -225,13 +254,16 @@ else:
if self.sslconn:
self.sslconn.close()
else:
- self.sock.close()
+ self.sock._sock.close()
def run (self):
self.running = True
if not self.server.starttls_server:
- if not self.wrap_conn():
+ if isinstance(self.sock, ssl.SSLSocket):
+ self.sslconn = self.sock
+ elif not self.wrap_conn():
return
+ self.show_conn_details()
while self.running:
try:
msg = self.read()
@@ -270,7 +302,9 @@ else:
def __init__(self, certificate, ssl_version=None,
certreqs=None, cacerts=None, expect_bad_connects=False,
- chatty=True, connectionchatty=False, starttls_server=False):
+ chatty=True, connectionchatty=False, starttls_server=False,
+ wrap_accepting_socket=False):
+
if ssl_version is None:
ssl_version = ssl.PROTOCOL_TLSv1
if certreqs is None:
@@ -284,8 +318,16 @@ else:
self.connectionchatty = connectionchatty
self.starttls_server = starttls_server
self.sock = socket.socket()
- self.port = test_support.bind_port(self.sock)
self.flag = None
+ if wrap_accepting_socket:
+ self.sock = ssl.wrap_socket(self.sock, server_side=True,
+ certfile=self.certificate,
+ cert_reqs = self.certreqs,
+ ca_certs = self.cacerts,
+ ssl_version = self.protocol)
+ if test_support.verbose and self.chatty:
+ sys.stdout.write(' server: wrapped server socket as %s\n' % str(self.sock))
+ self.port = test_support.bind_port(self.sock)
self.active = False
threading.Thread.__init__(self)
self.setDaemon(False)
@@ -316,13 +358,86 @@ else:
except:
if self.chatty:
handle_error("Test server failure:\n")
+ self.sock.close()
def stop (self):
self.active = False
- self.sock.close()
+ class AsyncoreEchoServer(threading.Thread):
+
+ class EchoServer (asyncore.dispatcher):
+
+ class ConnectionHandler (asyncore.dispatcher_with_send):
+
+ def __init__(self, conn, certfile):
+ asyncore.dispatcher_with_send.__init__(self, conn)
+ self.socket = ssl.wrap_socket(conn, server_side=True,
+ certfile=certfile,
+ do_handshake_on_connect=True)
+
+ def readable(self):
+ if isinstance(self.socket, ssl.SSLSocket):
+ while self.socket.pending() > 0:
+ self.handle_read_event()
+ return True
+
+ def handle_read(self):
+ data = self.recv(1024)
+ self.send(data.lower())
+
+ def handle_close(self):
+ if test_support.verbose:
+ sys.stdout.write(" server: closed connection %s\n" % self.socket)
+
+ def handle_error(self):
+ raise
+
+ def __init__(self, certfile):
+ self.certfile = certfile
+ asyncore.dispatcher.__init__(self)
+ self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.port = test_support.bind_port(self.socket)
+ self.listen(5)
+
+ def handle_accept(self):
+ sock_obj, addr = self.accept()
+ if test_support.verbose:
+ sys.stdout.write(" server: new connection from %s:%s\n" %addr)
+ self.ConnectionHandler(sock_obj, self.certfile)
+
+ def handle_error(self):
+ raise
+
+ def __init__(self, certfile):
+ self.flag = None
+ self.active = False
+ self.server = self.EchoServer(certfile)
+ self.port = self.server.port
+ threading.Thread.__init__(self)
+ self.setDaemon(True)
+
+ def __str__(self):
+ return "<%s %s>" % (self.__class__.__name__, self.server)
+
+ def start (self, flag=None):
+ self.flag = flag
+ threading.Thread.start(self)
+
+ def run (self):
+ self.active = True
+ if self.flag:
+ self.flag.set()
+ while self.active:
+ try:
+ asyncore.loop(1)
+ except:
+ pass
+
+ def stop (self):
+ self.active = False
+ self.server.close()
- class AsyncoreHTTPSServer(threading.Thread):
+ class SocketServerHTTPSServer(threading.Thread):
class HTTPSServer(HTTPServer):
@@ -335,6 +450,12 @@ else:
self.active_lock = threading.Lock()
self.allow_reuse_address = True
+ def __str__(self):
+ return ('<%s %s:%s>' %
+ (self.__class__.__name__,
+ self.server_name,
+ self.server_port))
+
def get_request (self):
# override this to wrap socket with SSL
sock, addr = self.socket.accept()
@@ -421,8 +542,8 @@ else:
# we override this to suppress logging unless "verbose"
if test_support.verbose:
- sys.stdout.write(" server (%s, %d, %s):\n [%s] %s\n" %
- (self.server.server_name,
+ sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" %
+ (self.server.server_address,
self.server.server_port,
self.request.cipher(),
self.log_date_time_string(),
@@ -440,9 +561,7 @@ else:
self.setDaemon(True)
def __str__(self):
- return '<%s %s:%d>' % (self.__class__.__name__,
- self.server.server_name,
- self.server.server_port)
+ return "<%s %s>" % (self.__class__.__name__, self.server)
def start (self, flag=None):
self.flag = flag
@@ -487,14 +606,16 @@ else:
def serverParamsTest (certfile, protocol, certreqs, cacertsfile,
client_certfile, client_protocol=None, indata="FOO\n",
- chatty=True, connectionchatty=False):
+ chatty=True, connectionchatty=False,
+ wrap_accepting_socket=False):
server = ThreadedEchoServer(certfile,
certreqs=certreqs,
ssl_version=protocol,
cacerts=cacertsfile,
chatty=chatty,
- connectionchatty=connectionchatty)
+ connectionchatty=connectionchatty,
+ wrap_accepting_socket=wrap_accepting_socket)
flag = threading.Event()
server.start(flag)
# wait for it to start
@@ -572,7 +693,7 @@ else:
ssl.get_protocol_name(server_protocol)))
- class ConnectedTests(unittest.TestCase):
+ class ThreadedTests(unittest.TestCase):
def testRudeShutdown(self):
@@ -600,7 +721,7 @@ else:
listener_gone.wait()
try:
ssl_sock = ssl.wrap_socket(s)
- except socket.sslerror:
+ except IOError:
pass
else:
raise test_support.TestFailed(
@@ -680,6 +801,9 @@ else:
def testMalformedCert(self):
badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
"badcert.pem"))
+ def testWrongCert(self):
+ badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
+ "wrongcert.pem"))
def testMalformedKey(self):
badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
"badkey.pem"))
@@ -796,9 +920,9 @@ else:
server.stop()
server.join()
- def testAsyncore(self):
+ def testSocketServer(self):
- server = AsyncoreHTTPSServer(CERTFILE)
+ server = SocketServerHTTPSServer(CERTFILE)
flag = threading.Event()
server.start(flag)
# wait for it to start
@@ -810,8 +934,8 @@ else:
d1 = open(CERTFILE, 'rb').read()
d2 = ''
# now fetch the same data from the HTTPS server
- url = 'https://%s:%d/%s' % (
- HOST, server.port, os.path.split(CERTFILE)[1])
+ url = 'https://127.0.0.1:%d/%s' % (
+ server.port, os.path.split(CERTFILE)[1])
f = urllib.urlopen(url)
dlen = f.info().getheader("content-length")
if dlen and (int(dlen) > 0):
@@ -834,6 +958,58 @@ else:
server.stop()
server.join()
+ def testWrappedAccept (self):
+
+ if test_support.verbose:
+ sys.stdout.write("\n")
+ serverParamsTest(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED,
+ CERTFILE, CERTFILE, ssl.PROTOCOL_SSLv23,
+ chatty=True, connectionchatty=True,
+ wrap_accepting_socket=True)
+
+
+ def testAsyncoreServer (self):
+
+ indata = "TEST MESSAGE of mixed case\n"
+
+ if test_support.verbose:
+ sys.stdout.write("\n")
+ server = AsyncoreEchoServer(CERTFILE)
+ flag = threading.Event()
+ server.start(flag)
+ # wait for it to start
+ flag.wait()
+ # try to connect
+ try:
+ try:
+ s = ssl.wrap_socket(socket.socket())
+ s.connect(('127.0.0.1', server.port))
+ except ssl.SSLError, x:
+ raise test_support.TestFailed("Unexpected SSL error: " + str(x))
+ except Exception, x:
+ raise test_support.TestFailed("Unexpected exception: " + str(x))
+ else:
+ if test_support.verbose:
+ sys.stdout.write(
+ " client: sending %s...\n" % (repr(indata)))
+ s.write(indata)
+ outdata = s.read()
+ if test_support.verbose:
+ sys.stdout.write(" client: read %s\n" % repr(outdata))
+ if outdata != indata.lower():
+ raise test_support.TestFailed(
+ "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
+ % (outdata[:min(len(outdata),20)], len(outdata),
+ indata[:min(len(indata),20)].lower(), len(indata)))
+ s.write("over\n")
+ if test_support.verbose:
+ sys.stdout.write(" client: closing connection.\n")
+ s.close()
+ finally:
+ server.stop()
+ # wait for server thread to end
+ server.join()
+
def test_main(verbose=False):
if skip_expected:
@@ -850,15 +1026,19 @@ def test_main(verbose=False):
not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)):
raise test_support.TestFailed("Can't read certificate files!")
+ TESTPORT = test_support.find_unused_port()
+ if not TESTPORT:
+ raise test_support.TestFailed("Can't find open port to test servers on!")
+
tests = [BasicTests]
if test_support.is_resource_enabled('network'):
- tests.append(NetworkTests)
+ tests.append(NetworkedTests)
if _have_threads:
thread_info = test_support.threading_setup()
if thread_info and test_support.is_resource_enabled('network'):
- tests.append(ConnectedTests)
+ tests.append(ThreadedTests)
test_support.run_unittest(*tests)
diff --git a/Lib/test/wrongcert.pem b/Lib/test/wrongcert.pem
new file mode 100644
index 0000000..5f92f9b
--- /dev/null
+++ b/Lib/test/wrongcert.pem
@@ -0,0 +1,32 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIICXAIBAAKBgQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnH
+FlbsVUg2Xtk6+bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6T
+f9lnNTwpSoeK24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQAB
+AoGAQFko4uyCgzfxr4Ezb4Mp5pN3Npqny5+Jey3r8EjSAX9Ogn+CNYgoBcdtFgbq
+1yif/0sK7ohGBJU9FUCAwrqNBI9ZHB6rcy7dx+gULOmRBGckln1o5S1+smVdmOsW
+7zUVLBVByKuNWqTYFlzfVd6s4iiXtAE2iHn3GCyYdlICwrECQQDhMQVxHd3EFbzg
+SFmJBTARlZ2GKA3c1g/h9/XbkEPQ9/RwI3vnjJ2RaSnjlfoLl8TOcf0uOGbOEyFe
+19RvCLXjAkEA1s+UE5ziF+YVkW3WolDCQ2kQ5WG9+ccfNebfh6b67B7Ln5iG0Sbg
+ky9cjsO3jbMJQtlzAQnH1850oRD5Gi51dQJAIbHCDLDZU9Ok1TI+I2BhVuA6F666
+lEZ7TeZaJSYq34OaUYUdrwG9OdqwZ9sy9LUav4ESzu2lhEQchCJrKMn23QJAReqs
+ZLHUeTjfXkVk7dHhWPWSlUZ6AhmIlA/AQ7Payg2/8wM/JkZEJEPvGVykms9iPUrv
+frADRr+hAGe43IewnQJBAJWKZllPgKuEBPwoEldHNS8nRu61D7HzxEzQ2xnfj+Nk
+2fgf1MAzzTRsikfGENhVsVWeqOcijWb6g5gsyCmlRpc=
+-----END RSA PRIVATE KEY-----
+-----BEGIN CERTIFICATE-----
+MIICsDCCAhmgAwIBAgIJAOqYOYFJfEEoMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
+BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX
+aWRnaXRzIFB0eSBMdGQwHhcNMDgwNjI2MTgxNTUyWhcNMDkwNjI2MTgxNTUyWjBF
+MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50
+ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB
+gQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnHFlbsVUg2Xtk6
++bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6Tf9lnNTwpSoeK
+24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQABo4GnMIGkMB0G
+A1UdDgQWBBTctMtI3EO9OjLI0x9Zo2ifkwIiNjB1BgNVHSMEbjBsgBTctMtI3EO9
+OjLI0x9Zo2ifkwIiNqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUt
+U3RhdGUxITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJAOqYOYFJ
+fEEoMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADgYEAQwa7jya/DfhaDn7E
+usPkpgIX8WCL2B1SqnRTXEZfBPPVq/cUmFGyEVRVATySRuMwi8PXbVcOhXXuocA+
+43W+iIsD9pXapCZhhOerCq18TC1dWK98vLUsoK8PMjB6e5H/O8bqojv0EeC+fyCw
+eSHj5jpC8iZKjCHBn+mAi4cQ514=
+-----END CERTIFICATE-----
diff --git a/Modules/_ssl.c b/Modules/_ssl.c
index 3f167b3..8fe72a5 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -2,14 +2,15 @@
SSL support based on patches by Brian E Gallew and Laszlo Kovacs.
Re-worked a bit by Bill Janssen to add server-side support and
- certificate decoding.
+ certificate decoding. Chris Stawarz contributed some non-blocking
+ patches.
This module is imported by ssl.py. It should *not* be used
directly.
XXX should partial writes be enabled, SSL_MODE_ENABLE_PARTIAL_WRITE?
- XXX what about SSL_MODE_AUTO_RETRY
+ XXX what about SSL_MODE_AUTO_RETRY?
*/
#include "Python.h"
@@ -265,8 +266,6 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file,
PySSLObject *self;
char *errstr = NULL;
int ret;
- int err;
- int sockstate;
int verification_mode;
self = PyObject_New(PySSLObject, &PySSL_Type); /* Create new object */
@@ -388,57 +387,6 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file,
SSL_set_accept_state(self->ssl);
PySSL_END_ALLOW_THREADS
- /* Actually negotiate SSL connection */
- /* XXX If SSL_connect() returns 0, it's also a failure. */
- sockstate = 0;
- do {
- PySSL_BEGIN_ALLOW_THREADS
- if (socket_type == PY_SSL_CLIENT)
- ret = SSL_connect(self->ssl);
- else
- ret = SSL_accept(self->ssl);
- err = SSL_get_error(self->ssl, ret);
- PySSL_END_ALLOW_THREADS
- if(PyErr_CheckSignals()) {
- goto fail;
- }
- if (err == SSL_ERROR_WANT_READ) {
- sockstate = check_socket_and_wait_for_timeout(Sock, 0);
- } else if (err == SSL_ERROR_WANT_WRITE) {
- sockstate = check_socket_and_wait_for_timeout(Sock, 1);
- } else {
- sockstate = SOCKET_OPERATION_OK;
- }
- if (sockstate == SOCKET_HAS_TIMED_OUT) {
- PyErr_SetString(PySSLErrorObject,
- ERRSTR("The connect operation timed out"));
- goto fail;
- } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) {
- PyErr_SetString(PySSLErrorObject,
- ERRSTR("Underlying socket has been closed."));
- goto fail;
- } else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) {
- PyErr_SetString(PySSLErrorObject,
- ERRSTR("Underlying socket too large for select()."));
- goto fail;
- } else if (sockstate == SOCKET_IS_NONBLOCKING) {
- break;
- }
- } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
- if (ret < 1) {
- PySSL_SetError(self, ret, __FILE__, __LINE__);
- goto fail;
- }
- self->ssl->debug = 1;
-
- PySSL_BEGIN_ALLOW_THREADS
- if ((self->peer_cert = SSL_get_peer_certificate(self->ssl))) {
- X509_NAME_oneline(X509_get_subject_name(self->peer_cert),
- self->server, X509_NAME_MAXLEN);
- X509_NAME_oneline(X509_get_issuer_name(self->peer_cert),
- self->issuer, X509_NAME_MAXLEN);
- }
- PySSL_END_ALLOW_THREADS
self->Socket = Sock;
Py_INCREF(self->Socket);
return self;
@@ -488,6 +436,65 @@ PyDoc_STRVAR(ssl_doc,
/* SSL object methods */
+static PyObject *PySSL_SSLdo_handshake(PySSLObject *self)
+{
+ int ret;
+ int err;
+ int sockstate;
+
+ /* Actually negotiate SSL connection */
+ /* XXX If SSL_do_handshake() returns 0, it's also a failure. */
+ sockstate = 0;
+ do {
+ PySSL_BEGIN_ALLOW_THREADS
+ ret = SSL_do_handshake(self->ssl);
+ err = SSL_get_error(self->ssl, ret);
+ PySSL_END_ALLOW_THREADS
+ if(PyErr_CheckSignals()) {
+ return NULL;
+ }
+ if (err == SSL_ERROR_WANT_READ) {
+ sockstate = check_socket_and_wait_for_timeout(self->Socket, 0);
+ } else if (err == SSL_ERROR_WANT_WRITE) {
+ sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
+ } else {
+ sockstate = SOCKET_OPERATION_OK;
+ }
+ if (sockstate == SOCKET_HAS_TIMED_OUT) {
+ PyErr_SetString(PySSLErrorObject,
+ ERRSTR("The handshake operation timed out"));
+ return NULL;
+ } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) {
+ PyErr_SetString(PySSLErrorObject,
+ ERRSTR("Underlying socket has been closed."));
+ return NULL;
+ } else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) {
+ PyErr_SetString(PySSLErrorObject,
+ ERRSTR("Underlying socket too large for select()."));
+ return NULL;
+ } else if (sockstate == SOCKET_IS_NONBLOCKING) {
+ break;
+ }
+ } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
+ if (ret < 1)
+ return PySSL_SetError(self, ret, __FILE__, __LINE__);
+ self->ssl->debug = 1;
+
+ if (self->peer_cert)
+ X509_free (self->peer_cert);
+ PySSL_BEGIN_ALLOW_THREADS
+ if ((self->peer_cert = SSL_get_peer_certificate(self->ssl))) {
+ X509_NAME_oneline(X509_get_subject_name(self->peer_cert),
+ self->server, X509_NAME_MAXLEN);
+ X509_NAME_oneline(X509_get_issuer_name(self->peer_cert),
+ self->issuer, X509_NAME_MAXLEN);
+ }
+ PySSL_END_ALLOW_THREADS
+
+ Py_INCREF(Py_None);
+ return Py_None;
+}
+
static PyObject *
PySSL_server(PySSLObject *self)
{
@@ -1127,7 +1134,9 @@ check_socket_and_wait_for_timeout(PySocketSockObject *s, int writing)
rc = select(s->sock_fd+1, &fds, NULL, NULL, &tv);
PySSL_END_ALLOW_THREADS
+#ifdef HAVE_POLL
normal_return:
+#endif
/* Return SOCKET_TIMED_OUT on timeout, SOCKET_OPERATION_OK otherwise
(when we are able to write or when there's something to read) */
return rc == 0 ? SOCKET_HAS_TIMED_OUT : SOCKET_OPERATION_OK;
@@ -1140,10 +1149,16 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)
int count;
int sockstate;
int err;
+ int nonblocking;
if (!PyArg_ParseTuple(args, "s#:write", &data, &count))
return NULL;
+ /* just in case the blocking state of the socket has been changed */
+ nonblocking = (self->Socket->sock_timeout >= 0.0);
+ BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
+ BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
+
sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject,
@@ -1200,6 +1215,25 @@ PyDoc_STRVAR(PySSL_SSLwrite_doc,
Writes the string s into the SSL object. Returns the number\n\
of bytes written.");
+static PyObject *PySSL_SSLpending(PySSLObject *self)
+{
+ int count = 0;
+
+ PySSL_BEGIN_ALLOW_THREADS
+ count = SSL_pending(self->ssl);
+ PySSL_END_ALLOW_THREADS
+ if (count < 0)
+ return PySSL_SetError(self, count, __FILE__, __LINE__);
+ else
+ return PyInt_FromLong(count);
+}
+
+PyDoc_STRVAR(PySSL_SSLpending_doc,
+"pending() -> count\n\
+\n\
+Returns the number of already decrypted bytes available for read,\n\
+pending on the connection.\n");
+
static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
{
PyObject *buf;
@@ -1207,6 +1241,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
int len = 1024;
int sockstate;
int err;
+ int nonblocking;
if (!PyArg_ParseTuple(args, "|i:read", &len))
return NULL;
@@ -1214,6 +1249,11 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
if (!(buf = PyString_FromStringAndSize((char *) 0, len)))
return NULL;
+ /* just in case the blocking state of the socket has been changed */
+ nonblocking = (self->Socket->sock_timeout >= 0.0);
+ BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
+ BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
+
/* first check if there are bytes ready to be read */
PySSL_BEGIN_ALLOW_THREADS
count = SSL_pending(self->ssl);
@@ -1232,9 +1272,18 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
Py_DECREF(buf);
return NULL;
} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) {
- /* should contain a zero-length string */
- _PyString_Resize(&buf, 0);
- return buf;
+ if (SSL_get_shutdown(self->ssl) !=
+ SSL_RECEIVED_SHUTDOWN)
+ {
+ Py_DECREF(buf);
+ PyErr_SetString(PySSLErrorObject,
+ "Socket closed without SSL shutdown handshake");
+ return NULL;
+ } else {
+ /* should contain a zero-length string */
+ _PyString_Resize(&buf, 0);
+ return buf;
+ }
}
}
do {
@@ -1285,16 +1334,54 @@ PyDoc_STRVAR(PySSL_SSLread_doc,
\n\
Read up to len bytes from the SSL socket.");
+static PyObject *PySSL_SSLshutdown(PySSLObject *self)
+{
+ int err;
+
+ /* Guard against closed socket */
+ if (self->Socket->sock_fd < 0) {
+ PyErr_SetString(PySSLErrorObject,
+ "Underlying socket has been closed.");
+ return NULL;
+ }
+
+ PySSL_BEGIN_ALLOW_THREADS
+ err = SSL_shutdown(self->ssl);
+ if (err == 0) {
+ /* we need to call it again to finish the shutdown */
+ err = SSL_shutdown(self->ssl);
+ }
+ PySSL_END_ALLOW_THREADS
+
+ if (err < 0)
+ return PySSL_SetError(self, err, __FILE__, __LINE__);
+ else {
+ Py_INCREF(self->Socket);
+ return (PyObject *) (self->Socket);
+ }
+}
+
+PyDoc_STRVAR(PySSL_SSLshutdown_doc,
+"shutdown(s) -> socket\n\
+\n\
+Does the SSL shutdown handshake with the remote end, and returns\n\
+the underlying socket object.");
+
static PyMethodDef PySSLMethods[] = {
+ {"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS},
{"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS,
PySSL_SSLwrite_doc},
{"read", (PyCFunction)PySSL_SSLread, METH_VARARGS,
PySSL_SSLread_doc},
+ {"pending", (PyCFunction)PySSL_SSLpending, METH_NOARGS,
+ PySSL_SSLpending_doc},
{"server", (PyCFunction)PySSL_server, METH_NOARGS},
{"issuer", (PyCFunction)PySSL_issuer, METH_NOARGS},
{"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS,
PySSL_peercert_doc},
{"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS},
+ {"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS,
+ PySSL_SSLshutdown_doc},
{NULL, NULL}
};