summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/ssl.py65
-rw-r--r--Lib/test/test_ssl.py112
2 files changed, 163 insertions, 14 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index f3da464..df5e98e 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -99,7 +99,7 @@ from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag
import _ssl # if we can't import it, let the error propagate
from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
-from _ssl import _SSLContext, MemoryBIO
+from _ssl import _SSLContext, MemoryBIO, SSLSession
from _ssl import (
SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
SSLSyscallError, SSLEOFError,
@@ -391,18 +391,18 @@ class SSLContext(_SSLContext):
def wrap_socket(self, sock, server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
- server_hostname=None):
+ server_hostname=None, session=None):
return SSLSocket(sock=sock, server_side=server_side,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
server_hostname=server_hostname,
- _context=self)
+ _context=self, _session=session)
def wrap_bio(self, incoming, outgoing, server_side=False,
- server_hostname=None):
+ server_hostname=None, session=None):
sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side,
server_hostname=server_hostname)
- return SSLObject(sslobj)
+ return SSLObject(sslobj, session=session)
def set_npn_protocols(self, npn_protocols):
protos = bytearray()
@@ -572,10 +572,12 @@ class SSLObject:
* The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
"""
- def __init__(self, sslobj, owner=None):
+ def __init__(self, sslobj, owner=None, session=None):
self._sslobj = sslobj
# Note: _sslobj takes a weak reference to owner
self._sslobj.owner = owner or self
+ if session is not None:
+ self._sslobj.session = session
@property
def context(self):
@@ -587,6 +589,20 @@ class SSLObject:
self._sslobj.context = ctx
@property
+ def session(self):
+ """The SSLSession for client socket."""
+ return self._sslobj.session
+
+ @session.setter
+ def session(self, session):
+ self._sslobj.session = session
+
+ @property
+ def session_reused(self):
+ """Was the client session reused during handshake"""
+ return self._sslobj.session_reused
+
+ @property
def server_side(self):
"""Whether this is a server-side socket."""
return self._sslobj.server_side
@@ -703,7 +719,7 @@ class SSLSocket(socket):
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
server_hostname=None,
- _context=None):
+ _context=None, _session=None):
if _context:
self._context = _context
@@ -735,11 +751,16 @@ class SSLSocket(socket):
# mixed in.
if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
raise NotImplementedError("only stream sockets are supported")
- if server_side and server_hostname:
- raise ValueError("server_hostname can only be specified "
- "in client mode")
+ if server_side:
+ if server_hostname:
+ raise ValueError("server_hostname can only be specified "
+ "in client mode")
+ if _session is not None:
+ raise ValueError("session can only be specified in "
+ "client mode")
if self._context.check_hostname and not server_hostname:
raise ValueError("check_hostname requires server_hostname")
+ self._session = _session
self.server_side = server_side
self.server_hostname = server_hostname
self.do_handshake_on_connect = do_handshake_on_connect
@@ -775,7 +796,8 @@ class SSLSocket(socket):
try:
sslobj = self._context._wrap_socket(self, server_side,
server_hostname)
- self._sslobj = SSLObject(sslobj, owner=self)
+ self._sslobj = SSLObject(sslobj, owner=self,
+ session=self._session)
if do_handshake_on_connect:
timeout = self.gettimeout()
if timeout == 0.0:
@@ -796,6 +818,24 @@ class SSLSocket(socket):
self._context = ctx
self._sslobj.context = ctx
+ @property
+ def session(self):
+ """The SSLSession for client socket."""
+ if self._sslobj is not None:
+ return self._sslobj.session
+
+ @session.setter
+ def session(self, session):
+ self._session = session
+ if self._sslobj is not None:
+ self._sslobj.session = session
+
+ @property
+ def session_reused(self):
+ """Was the client session reused during handshake"""
+ if self._sslobj is not None:
+ return self._sslobj.session_reused
+
def dup(self):
raise NotImplemented("Can't dup() %s instances" %
self.__class__.__name__)
@@ -1028,7 +1068,8 @@ class SSLSocket(socket):
if self._connected:
raise ValueError("attempt to connect already-connected SSLSocket!")
sslobj = self.context._wrap_socket(self, False, self.server_hostname)
- self._sslobj = SSLObject(sslobj, owner=self)
+ self._sslobj = SSLObject(sslobj, owner=self,
+ session=self._session)
try:
if connect_ex:
rc = socket.connect_ex(self, addr)
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index aed226c..61744ae 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -2163,7 +2163,8 @@ if _have_threads:
self.server.close()
def server_params_test(client_context, server_context, indata=b"FOO\n",
- chatty=True, connectionchatty=False, sni_name=None):
+ chatty=True, connectionchatty=False, sni_name=None,
+ session=None):
"""
Launch a server, connect a client to it and try various reads
and writes.
@@ -2174,7 +2175,7 @@ if _have_threads:
connectionchatty=False)
with server:
with client_context.wrap_socket(socket.socket(),
- server_hostname=sni_name) as s:
+ server_hostname=sni_name, session=session) as s:
s.connect((HOST, server.port))
for arg in [indata, bytearray(indata), memoryview(indata)]:
if connectionchatty:
@@ -2202,6 +2203,8 @@ if _have_threads:
'client_alpn_protocol': s.selected_alpn_protocol(),
'client_npn_protocol': s.selected_npn_protocol(),
'version': s.version(),
+ 'session_reused': s.session_reused,
+ 'session': s.session,
})
s.close()
stats['server_alpn_protocols'] = server.selected_alpn_protocols
@@ -3412,6 +3415,111 @@ if _have_threads:
s.sendfile(file)
self.assertEqual(s.recv(1024), TEST_DATA)
+ def test_session(self):
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.load_verify_locations(SIGNING_CA)
+
+ # first conncetion without session
+ stats = server_params_test(client_context, server_context)
+ session = stats['session']
+ self.assertTrue(session.id)
+ self.assertGreater(session.time, 0)
+ self.assertGreater(session.timeout, 0)
+ self.assertTrue(session.has_ticket)
+ if ssl.OPENSSL_VERSION_INFO > (1, 0, 1):
+ self.assertGreater(session.ticket_lifetime_hint, 0)
+ self.assertFalse(stats['session_reused'])
+ sess_stat = server_context.session_stats()
+ self.assertEqual(sess_stat['accept'], 1)
+ self.assertEqual(sess_stat['hits'], 0)
+
+ # reuse session
+ stats = server_params_test(client_context, server_context, session=session)
+ sess_stat = server_context.session_stats()
+ self.assertEqual(sess_stat['accept'], 2)
+ self.assertEqual(sess_stat['hits'], 1)
+ self.assertTrue(stats['session_reused'])
+ session2 = stats['session']
+ self.assertEqual(session2.id, session.id)
+ self.assertEqual(session2, session)
+ self.assertIsNot(session2, session)
+ self.assertGreaterEqual(session2.time, session.time)
+ self.assertGreaterEqual(session2.timeout, session.timeout)
+
+ # another one without session
+ stats = server_params_test(client_context, server_context)
+ self.assertFalse(stats['session_reused'])
+ session3 = stats['session']
+ self.assertNotEqual(session3.id, session.id)
+ self.assertNotEqual(session3, session)
+ sess_stat = server_context.session_stats()
+ self.assertEqual(sess_stat['accept'], 3)
+ self.assertEqual(sess_stat['hits'], 1)
+
+ # reuse session again
+ stats = server_params_test(client_context, server_context, session=session)
+ self.assertTrue(stats['session_reused'])
+ session4 = stats['session']
+ self.assertEqual(session4.id, session.id)
+ self.assertEqual(session4, session)
+ self.assertGreaterEqual(session4.time, session.time)
+ self.assertGreaterEqual(session4.timeout, session.timeout)
+ sess_stat = server_context.session_stats()
+ self.assertEqual(sess_stat['accept'], 4)
+ self.assertEqual(sess_stat['hits'], 2)
+
+ def test_session_handling(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+
+ context2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context2.verify_mode = ssl.CERT_REQUIRED
+ context2.load_verify_locations(CERTFILE)
+ context2.load_cert_chain(CERTFILE)
+
+ server = ThreadedEchoServer(context=context, chatty=False)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
+ # session is None before handshake
+ self.assertEqual(s.session, None)
+ self.assertEqual(s.session_reused, None)
+ s.connect((HOST, server.port))
+ session = s.session
+ self.assertTrue(session)
+ with self.assertRaises(TypeError) as e:
+ s.session = object
+ self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
+
+ with context.wrap_socket(socket.socket()) as s:
+ s.connect((HOST, server.port))
+ # cannot set session after handshake
+ with self.assertRaises(ValueError) as e:
+ s.session = session
+ self.assertEqual(str(e.exception),
+ 'Cannot set session after handshake.')
+
+ with context.wrap_socket(socket.socket()) as s:
+ # can set session before handshake and before the
+ # connection was established
+ s.session = session
+ s.connect((HOST, server.port))
+ self.assertEqual(s.session.id, session.id)
+ self.assertEqual(s.session, session)
+ self.assertEqual(s.session_reused, True)
+
+ with context2.wrap_socket(socket.socket()) as s:
+ # cannot re-use session with a different SSLContext
+ with self.assertRaises(ValueError) as e:
+ s.session = session
+ s.connect((HOST, server.port))
+ self.assertEqual(str(e.exception),
+ 'Session refers to a different SSLContext.')
+
def test_main(verbose=False):
if support.verbose: