summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/socket.py6
-rw-r--r--Lib/ssl.py46
-rw-r--r--Lib/test/test_ssl.py184
-rw-r--r--Modules/_ssl.c57
4 files changed, 225 insertions, 68 deletions
diff --git a/Lib/socket.py b/Lib/socket.py
index 62eb82d..eb87673 100644
--- a/Lib/socket.py
+++ b/Lib/socket.py
@@ -174,11 +174,13 @@ class socket(_socket.socket):
if self._closed:
self.close()
+ def _real_close(self):
+ _socket.socket.close(self)
+
def close(self):
self._closed = True
if self._io_refs <= 0:
- _socket.socket.close(self)
-
+ self._real_close()
def fromfd(fd, family, type, proto=0):
""" fromfd(fd, family, type[, proto]) -> socket object
diff --git a/Lib/ssl.py b/Lib/ssl.py
index be13866..c229cd3 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -80,6 +80,7 @@ from socket import getnameinfo as _getnameinfo
from socket import error as socket_error
from socket import dup as _dup
import base64 # for DER-to-PEM translation
+import traceback
class SSLSocket(socket):
@@ -94,16 +95,13 @@ class SSLSocket(socket):
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True):
- self._base = None
-
if sock is not None:
- # copied this code from socket.accept()
- fd = sock.fileno()
- nfd = _dup(fd)
- socket.__init__(self, family=sock.family, type=sock.type,
- proto=sock.proto, fileno=nfd)
+ socket.__init__(self,
+ family=sock.family,
+ type=sock.type,
+ proto=sock.proto,
+ fileno=_dup(sock.fileno()))
sock.close()
- sock = None
elif fileno is not None:
socket.__init__(self, fileno=fileno)
else:
@@ -136,10 +134,6 @@ class SSLSocket(socket):
self.close()
raise x
- if sock and (self.fileno() != sock.fileno()):
- self._base = sock
- else:
- self._base = None
self.keyfile = keyfile
self.certfile = certfile
self.cert_reqs = cert_reqs
@@ -156,19 +150,23 @@ class SSLSocket(socket):
# raise an exception here if you wish to check for spurious closes
pass
- def read(self, len=None, buffer=None):
+ def read(self, len=0, buffer=None):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
self._checkClosed()
try:
if buffer:
- return self._sslobj.read(buffer, len)
+ v = self._sslobj.read(buffer, len)
else:
- return self._sslobj.read(len or 1024)
+ v = self._sslobj.read(len or 1024)
+ return v
except SSLError as x:
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
- return b''
+ if buffer:
+ return 0
+ else:
+ return b''
else:
raise
@@ -269,7 +267,6 @@ class SSLSocket(socket):
while True:
try:
v = self.read(nbytes, buffer)
- sys.stdout.flush()
return v
except SSLError as x:
if x.args[0] == SSL_ERROR_WANT_READ:
@@ -302,9 +299,7 @@ class SSLSocket(socket):
def _real_close(self):
self._sslobj = None
# self._closed = True
- if self._base:
- self._base.close()
- socket.close(self)
+ socket._real_close(self)
def do_handshake(self, block=False):
"""Perform a TLS/SSL handshake."""
@@ -329,8 +324,12 @@ class SSLSocket(socket):
self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs)
- if self.do_handshake_on_connect:
- self.do_handshake()
+ try:
+ if self.do_handshake_on_connect:
+ self.do_handshake()
+ except:
+ self._sslobj = None
+ raise
def accept(self):
"""Accepts a new connection from a remote client, and returns
@@ -348,10 +347,11 @@ class SSLSocket(socket):
self.do_handshake_on_connect),
addr)
-
def __del__(self):
+ # sys.stderr.write("__del__ on %s\n" % repr(self))
self._real_close()
+
def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 18df3f4..81943a5 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -13,6 +13,7 @@ import pprint
import urllib, urlparse
import shutil
import traceback
+import asyncore
from BaseHTTPServer import HTTPServer
from SimpleHTTPServer import SimpleHTTPRequestHandler
@@ -79,27 +80,6 @@ class BasicTests(unittest.TestCase):
class NetworkedTests(unittest.TestCase):
- def testFetchServerCert(self):
-
- pem = ssl.get_server_certificate(("svn.python.org", 443))
- if not pem:
- raise test_support.TestFailed("No server certificate on svn.python.org:443!")
-
- try:
- pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
- except ssl.SSLError as x:
- #should fail
- if test_support.verbose:
- sys.stdout.write("%s\n" % x)
- else:
- raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem)
-
- pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
- if not pem:
- raise test_support.TestFailed("No server certificate on svn.python.org:443!")
- if test_support.verbose:
- sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
-
def testConnect(self):
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
@@ -155,6 +135,29 @@ class NetworkedTests(unittest.TestCase):
if test_support.verbose:
sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
+ def testFetchServerCert(self):
+
+ pem = ssl.get_server_certificate(("svn.python.org", 443))
+ if not pem:
+ raise test_support.TestFailed("No server certificate on svn.python.org:443!")
+
+ return
+
+ try:
+ pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
+ except ssl.SSLError as x:
+ #should fail
+ if test_support.verbose:
+ sys.stdout.write("%s\n" % x)
+ else:
+ raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem)
+
+ pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
+ if not pem:
+ raise test_support.TestFailed("No server certificate on svn.python.org:443!")
+ if test_support.verbose:
+ sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
+
try:
import threading
@@ -333,7 +336,9 @@ else:
def stop (self):
self.active = False
- class AsyncoreHTTPSServer(threading.Thread):
+ class OurHTTPSServer(threading.Thread):
+
+ # This one's based on HTTPServer, which is based on SocketServer
class HTTPSServer(HTTPServer):
@@ -463,6 +468,92 @@ else:
self.server.server_close()
+ class AsyncoreEchoServer(threading.Thread):
+
+ # this one's based on asyncore.dispatcher
+
+ class EchoServer (asyncore.dispatcher):
+
+ class ConnectionHandler (asyncore.dispatcher_with_send):
+
+ def __init__(self, conn, certfile):
+ self.socket = ssl.wrap_socket(conn, server_side=True,
+ certfile=certfile,
+ do_handshake_on_connect=False)
+ asyncore.dispatcher_with_send.__init__(self, self.socket)
+ # now we have to do the handshake
+ # we'll just do it the easy way, and block the connection
+ # till it's finished. If we were doing it right, we'd
+ # do this in multiple calls to handle_read...
+ self.do_handshake(block=True)
+
+ def readable(self):
+ if isinstance(self.socket, ssl.SSLSocket):
+ while self.socket.pending() > 0:
+ self.handle_read_event()
+ return True
+
+ def handle_read(self):
+ data = self.recv(1024)
+ if test_support.verbose:
+ sys.stdout.write(" server: read %s from client\n" % repr(data))
+ if not data:
+ self.close()
+ else:
+ self.send(str(data, 'ASCII', 'strict').lower().encode('ASCII', 'strict'))
+
+ def handle_close(self):
+ if test_support.verbose:
+ sys.stdout.write(" server: closed connection %s\n" % self.socket)
+
+ def handle_error(self):
+ raise
+
+ def __init__(self, port, certfile):
+ self.port = port
+ self.certfile = certfile
+ asyncore.dispatcher.__init__(self)
+ self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.bind(('', port))
+ self.listen(5)
+
+ def handle_accept(self):
+ sock_obj, addr = self.accept()
+ if test_support.verbose:
+ sys.stdout.write(" server: new connection from %s:%s\n" %addr)
+ self.ConnectionHandler(sock_obj, self.certfile)
+
+ def handle_error(self):
+ raise
+
+ def __init__(self, port, certfile):
+ self.flag = None
+ self.active = False
+ self.server = self.EchoServer(port, certfile)
+ threading.Thread.__init__(self)
+ self.setDaemon(True)
+
+ def __str__(self):
+ return "<%s %s>" % (self.__class__.__name__, self.server)
+
+ def start (self, flag=None):
+ self.flag = flag
+ threading.Thread.start(self)
+
+ def run (self):
+ self.active = True
+ if self.flag:
+ self.flag.set()
+ while self.active:
+ try:
+ asyncore.loop(1)
+ except:
+ pass
+
+ def stop (self):
+ self.active = False
+ self.server.close()
+
def badCertTest (certfile):
server = ThreadedEchoServer(TESTPORT, CERTFILE,
certreqs=ssl.CERT_REQUIRED,
@@ -509,6 +600,7 @@ else:
client_protocol = protocol
try:
s = ssl.wrap_socket(socket.socket(),
+ server_side=False,
certfile=client_certfile,
ca_certs=cacertsfile,
cert_reqs=certreqs,
@@ -811,11 +903,9 @@ else:
server.stop()
server.join()
- class AsyncoreTests(unittest.TestCase):
-
- def testAsyncore(self):
+ def testSocketServer(self):
- server = AsyncoreHTTPSServer(TESTPORT, CERTFILE)
+ server = OurHTTPSServer(TESTPORT, CERTFILE)
flag = threading.Event()
server.start(flag)
# wait for it to start
@@ -853,6 +943,47 @@ else:
server.stop()
server.join()
+ def testAsyncoreServer(self):
+
+ if test_support.verbose:
+ sys.stdout.write("\n")
+
+ indata="FOO\n"
+ server = AsyncoreEchoServer(TESTPORT, CERTFILE)
+ flag = threading.Event()
+ server.start(flag)
+ # wait for it to start
+ flag.wait()
+ # try to connect
+ try:
+ s = ssl.wrap_socket(socket.socket())
+ s.connect(('127.0.0.1', TESTPORT))
+ except ssl.SSLError as x:
+ raise test_support.TestFailed("Unexpected SSL error: " + str(x))
+ except Exception as x:
+ raise test_support.TestFailed("Unexpected exception: " + str(x))
+ else:
+ if test_support.verbose:
+ sys.stdout.write(
+ " client: sending %s...\n" % (repr(indata)))
+ s.sendall(indata.encode('ASCII', 'strict'))
+ outdata = s.recv()
+ if test_support.verbose:
+ sys.stdout.write(" client: read %s\n" % repr(outdata))
+ outdata = str(outdata, 'ASCII', 'strict')
+ if outdata != indata.lower():
+ raise test_support.TestFailed(
+ "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
+ % (repr(outdata[:min(len(outdata),20)]), len(outdata),
+ repr(indata[:min(len(indata),20)].lower()), len(indata)))
+ s.write("over\n".encode("ASCII", "strict"))
+ if test_support.verbose:
+ sys.stdout.write(" client: closing connection.\n")
+ s.close()
+ finally:
+ server.stop()
+ server.join()
+
def findtestsocket(start, end):
def testbind(i):
@@ -900,7 +1031,6 @@ def test_main(verbose=False):
thread_info = test_support.threading_setup()
if thread_info and test_support.is_resource_enabled('network'):
tests.append(ThreadedTests)
- tests.append(AsyncoreTests)
test_support.run_unittest(*tests)
diff --git a/Modules/_ssl.c b/Modules/_ssl.c
index bd3f172..7ab2297 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -46,6 +46,7 @@ enum py_ssl_error {
PY_SSL_ERROR_WANT_CONNECT,
/* start of non ssl.h errorcodes */
PY_SSL_ERROR_EOF, /* special case of SSL_ERROR_SYSCALL */
+ PY_SSL_ERROR_NO_SOCKET, /* socket has been GC'd */
PY_SSL_ERROR_INVALID_ERROR_CODE
};
@@ -111,7 +112,7 @@ static unsigned int _ssl_locks_count = 0;
typedef struct {
PyObject_HEAD
- PySocketSockObject *Socket; /* Socket on which we're layered */
+ PyObject *Socket; /* weakref to socket on which we're layered */
SSL_CTX* ctx;
SSL* ssl;
X509* peer_cert;
@@ -188,13 +189,15 @@ PySSL_SetError(PySSLObject *obj, int ret, char *filename, int lineno)
{
unsigned long e = ERR_get_error();
if (e == 0) {
- if (ret == 0 || !obj->Socket) {
+ PySocketSockObject *s
+ = (PySocketSockObject *) PyWeakref_GetObject(obj->Socket);
+ if (ret == 0 || (((PyObject *)s) == Py_None)) {
p = PY_SSL_ERROR_EOF;
errstr =
"EOF occurred in violation of protocol";
} else if (ret == -1) {
/* underlying BIO reported an I/O error */
- return obj->Socket->errorhandler();
+ return s->errorhandler();
} else { /* possible? */
p = PY_SSL_ERROR_SYSCALL;
errstr = "Some I/O error occurred";
@@ -383,8 +386,7 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file,
SSL_set_accept_state(self->ssl);
PySSL_END_ALLOW_THREADS
- self->Socket = Sock;
- Py_INCREF(self->Socket);
+ self->Socket = PyWeakref_NewRef((PyObject *) Sock, Py_None);
return self;
fail:
if (errstr)
@@ -442,6 +444,14 @@ static PyObject *PySSL_SSLdo_handshake(PySSLObject *self)
/* XXX If SSL_do_handshake() returns 0, it's also a failure. */
sockstate = 0;
do {
+ PySocketSockObject *sock
+ = (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
+ if (((PyObject*)sock) == Py_None) {
+ _setSSLError("Underlying socket connection gone",
+ PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
+ return NULL;
+ }
+
PySSL_BEGIN_ALLOW_THREADS
ret = SSL_do_handshake(self->ssl);
err = SSL_get_error(self->ssl, ret);
@@ -450,9 +460,9 @@ static PyObject *PySSL_SSLdo_handshake(PySSLObject *self)
return NULL;
}
if (err == SSL_ERROR_WANT_READ) {
- sockstate = check_socket_and_wait_for_timeout(self->Socket, 0);
+ sockstate = check_socket_and_wait_for_timeout(sock, 0);
} else if (err == SSL_ERROR_WANT_WRITE) {
- sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
+ sockstate = check_socket_and_wait_for_timeout(sock, 1);
} else {
sockstate = SOCKET_OPERATION_OK;
}
@@ -1140,16 +1150,24 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)
int sockstate;
int err;
int nonblocking;
+ PySocketSockObject *sock
+ = (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
+
+ if (((PyObject*)sock) == Py_None) {
+ _setSSLError("Underlying socket connection gone",
+ PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
+ return NULL;
+ }
if (!PyArg_ParseTuple(args, "y#:write", &data, &count))
return NULL;
/* just in case the blocking state of the socket has been changed */
- nonblocking = (self->Socket->sock_timeout >= 0.0);
+ nonblocking = (sock->sock_timeout >= 0.0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
- sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
+ sockstate = check_socket_and_wait_for_timeout(sock, 1);
if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject,
"The write operation timed out");
@@ -1174,10 +1192,10 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)
}
if (err == SSL_ERROR_WANT_READ) {
sockstate =
- check_socket_and_wait_for_timeout(self->Socket, 0);
+ check_socket_and_wait_for_timeout(sock, 0);
} else if (err == SSL_ERROR_WANT_WRITE) {
sockstate =
- check_socket_and_wait_for_timeout(self->Socket, 1);
+ check_socket_and_wait_for_timeout(sock, 1);
} else {
sockstate = SOCKET_OPERATION_OK;
}
@@ -1233,10 +1251,17 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
int sockstate;
int err;
int nonblocking;
+ PySocketSockObject *sock
+ = (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
+
+ if (((PyObject*)sock) == Py_None) {
+ _setSSLError("Underlying socket connection gone",
+ PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
+ return NULL;
+ }
if (!PyArg_ParseTuple(args, "|Oi:read", &buf, &count))
return NULL;
-
if ((buf == NULL) || (buf == Py_None)) {
if (!(buf = PyBytes_FromStringAndSize((char *) 0, len)))
return NULL;
@@ -1254,7 +1279,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
}
/* just in case the blocking state of the socket has been changed */
- nonblocking = (self->Socket->sock_timeout >= 0.0);
+ nonblocking = (sock->sock_timeout >= 0.0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
@@ -1264,7 +1289,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
PySSL_END_ALLOW_THREADS
if (!count) {
- sockstate = check_socket_and_wait_for_timeout(self->Socket, 0);
+ sockstate = check_socket_and_wait_for_timeout(sock, 0);
if (sockstate == SOCKET_HAS_TIMED_OUT) {
PyErr_SetString(PySSLErrorObject,
"The read operation timed out");
@@ -1299,10 +1324,10 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
}
if (err == SSL_ERROR_WANT_READ) {
sockstate =
- check_socket_and_wait_for_timeout(self->Socket, 0);
+ check_socket_and_wait_for_timeout(sock, 0);
} else if (err == SSL_ERROR_WANT_WRITE) {
sockstate =
- check_socket_and_wait_for_timeout(self->Socket, 1);
+ check_socket_and_wait_for_timeout(sock, 1);
} else if ((err == SSL_ERROR_ZERO_RETURN) &&
(SSL_get_shutdown(self->ssl) ==
SSL_RECEIVED_SHUTDOWN))