summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/ssl.py10
-rw-r--r--Lib/test/test_ssl.py23
-rw-r--r--Modules/_ssl.c40
3 files changed, 69 insertions, 4 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index c072cd9..aa301295 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -75,10 +75,10 @@ from _ssl import (
SSL_ERROR_INVALID_ERROR_CODE,
)
-from socket import socket, AF_INET, SOCK_STREAM, error
from socket import getnameinfo as _getnameinfo
from socket import error as socket_error
from socket import dup as _dup
+from socket import socket, AF_INET, SOCK_STREAM
import base64 # for DER-to-PEM translation
import traceback
@@ -296,6 +296,14 @@ class SSLSocket(socket):
self._sslobj = None
socket.shutdown(self, how)
+ def unwrap (self):
+ if self._sslobj:
+ s = self._sslobj.shutdown()
+ self._sslobj = None
+ return s
+ else:
+ raise ValueError("No SSL wrapper around " + str(self))
+
def _real_close(self):
self._sslobj = None
# self._closed = True
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 9e36e80..a40a35d 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -279,6 +279,15 @@ else:
self.write("OK\n".encode("ASCII", "strict"))
if not self.wrap_conn():
return
+ elif (self.server.starttls_server and self.sslconn
+ and amsg.strip() == 'ENDTLS'):
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
+ self.write("OK\n".encode("ASCII", "strict"))
+ self.sock = self.sslconn.unwrap()
+ self.sslconn = None
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: connection is now unencrypted...\n")
else:
if (support.verbose and
self.server.connectionchatty):
@@ -868,7 +877,7 @@ else:
def testSTARTTLS (self):
- msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4")
+ msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", "msg 5", "msg 6")
server = ThreadedEchoServer(CERTFILE,
ssl_version=ssl.PROTOCOL_TLSv1,
@@ -910,8 +919,16 @@ else:
" client: read %s from server, starting TLS...\n"
% repr(msg))
conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
-
wrapped = True
+ elif (indata == "ENDTLS" and
+ str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")):
+ if support.verbose:
+ msg = str(outdata, 'ASCII', 'replace')
+ sys.stdout.write(
+ " client: read %s from server, ending TLS...\n"
+ % repr(msg))
+ s = conn.unwrap()
+ wrapped = False
else:
if support.verbose:
msg = str(outdata, 'ASCII', 'replace')
@@ -922,7 +939,7 @@ else:
if wrapped:
conn.write("over\n".encode("ASCII", "strict"))
else:
- s.send("over\n")
+ s.send("over\n".encode("ASCII", "strict"))
if wrapped:
conn.close()
else:
diff --git a/Modules/_ssl.c b/Modules/_ssl.c
index 48318a8..d9cbbd0 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -1370,6 +1370,42 @@ PyDoc_STRVAR(PySSL_SSLread_doc,
\n\
Read up to len bytes from the SSL socket.");
+static PyObject *PySSL_SSLshutdown(PySSLObject *self)
+{
+ int err;
+ PySocketSockObject *sock
+ = (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
+
+ /* Guard against closed socket */
+ if ((((PyObject*)sock) == Py_None) || (sock->sock_fd < 0)) {
+ _setSSLError("Underlying socket connection gone",
+ PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
+ return NULL;
+ }
+
+ PySSL_BEGIN_ALLOW_THREADS
+ err = SSL_shutdown(self->ssl);
+ if (err == 0) {
+ /* we need to call it again to finish the shutdown */
+ err = SSL_shutdown(self->ssl);
+ }
+ PySSL_END_ALLOW_THREADS
+
+ if (err < 0)
+ return PySSL_SetError(self, err, __FILE__, __LINE__);
+ else {
+ Py_INCREF(sock);
+ return (PyObject *) sock;
+ }
+}
+
+PyDoc_STRVAR(PySSL_SSLshutdown_doc,
+"shutdown(s) -> socket\n\
+\n\
+Does the SSL shutdown handshake with the remote end, and returns\n\
+the underlying socket object.");
+
+
static PyMethodDef PySSLMethods[] = {
{"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS},
{"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS,
@@ -1381,6 +1417,8 @@ static PyMethodDef PySSLMethods[] = {
{"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS,
PySSL_peercert_doc},
{"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS},
+ {"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS,
+ PySSL_SSLshutdown_doc},
{NULL, NULL}
};
@@ -1480,6 +1518,8 @@ fails or if it does provide enough data to seed PRNG.");
#endif
+
+
/* List of functions exported by this module. */
static PyMethodDef PySSL_methods[] = {