summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio/sslproto.py
diff options
context:
space:
mode:
authorAndrew Svetlov <andrew.svetlov@gmail.com>2021-05-02 21:34:15 (GMT)
committerGitHub <noreply@github.com>2021-05-02 21:34:15 (GMT)
commit5fb06edbbb769561e245d0fe13002bab50e2ae60 (patch)
treea6341e32a1140447b2d37a3a47fedb9d5043c75d /Lib/asyncio/sslproto.py
parentc96cc089f60d2bf7e003c27413c3239ee9de2990 (diff)
downloadcpython-5fb06edbbb769561e245d0fe13002bab50e2ae60.zip
cpython-5fb06edbbb769561e245d0fe13002bab50e2ae60.tar.gz
cpython-5fb06edbbb769561e245d0fe13002bab50e2ae60.tar.bz2
bpo-44011: New asyncio ssl implementation (#17975)
Diffstat (limited to 'Lib/asyncio/sslproto.py')
-rw-r--r--Lib/asyncio/sslproto.py1058
1 files changed, 621 insertions, 437 deletions
diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py
index cad25b2..e71875b 100644
--- a/Lib/asyncio/sslproto.py
+++ b/Lib/asyncio/sslproto.py
@@ -1,4 +1,5 @@
import collections
+import enum
import warnings
try:
import ssl
@@ -6,10 +7,37 @@ except ImportError: # pragma: no cover
ssl = None
from . import constants
+from . import exceptions
from . import protocols
from . import transports
from .log import logger
+SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError)
+
+
+class SSLProtocolState(enum.Enum):
+ UNWRAPPED = "UNWRAPPED"
+ DO_HANDSHAKE = "DO_HANDSHAKE"
+ WRAPPED = "WRAPPED"
+ FLUSHING = "FLUSHING"
+ SHUTDOWN = "SHUTDOWN"
+
+
+class AppProtocolState(enum.Enum):
+ # This tracks the state of app protocol (https://git.io/fj59P):
+ #
+ # INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST
+ #
+ # * cm: connection_made()
+ # * dr: data_received()
+ # * er: eof_received()
+ # * cl: connection_lost()
+
+ STATE_INIT = "STATE_INIT"
+ STATE_CON_MADE = "STATE_CON_MADE"
+ STATE_EOF = "STATE_EOF"
+ STATE_CON_LOST = "STATE_CON_LOST"
+
def _create_transport_context(server_side, server_hostname):
if server_side:
@@ -25,269 +53,35 @@ def _create_transport_context(server_side, server_hostname):
return sslcontext
-# States of an _SSLPipe.
-_UNWRAPPED = "UNWRAPPED"
-_DO_HANDSHAKE = "DO_HANDSHAKE"
-_WRAPPED = "WRAPPED"
-_SHUTDOWN = "SHUTDOWN"
-
-
-class _SSLPipe(object):
- """An SSL "Pipe".
-
- An SSL pipe allows you to communicate with an SSL/TLS protocol instance
- through memory buffers. It can be used to implement a security layer for an
- existing connection where you don't have access to the connection's file
- descriptor, or for some reason you don't want to use it.
-
- An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode,
- data is passed through untransformed. In wrapped mode, application level
- data is encrypted to SSL record level data and vice versa. The SSL record
- level is the lowest level in the SSL protocol suite and is what travels
- as-is over the wire.
-
- An SslPipe initially is in "unwrapped" mode. To start SSL, call
- do_handshake(). To shutdown SSL again, call unwrap().
- """
-
- max_size = 256 * 1024 # Buffer size passed to read()
-
- def __init__(self, context, server_side, server_hostname=None):
- """
- The *context* argument specifies the ssl.SSLContext to use.
-
- The *server_side* argument indicates whether this is a server side or
- client side transport.
-
- The optional *server_hostname* argument can be used to specify the
- hostname you are connecting to. You may only specify this parameter if
- the _ssl module supports Server Name Indication (SNI).
- """
- self._context = context
- self._server_side = server_side
- self._server_hostname = server_hostname
- self._state = _UNWRAPPED
- self._incoming = ssl.MemoryBIO()
- self._outgoing = ssl.MemoryBIO()
- self._sslobj = None
- self._need_ssldata = False
- self._handshake_cb = None
- self._shutdown_cb = None
-
- @property
- def context(self):
- """The SSL context passed to the constructor."""
- return self._context
-
- @property
- def ssl_object(self):
- """The internal ssl.SSLObject instance.
-
- Return None if the pipe is not wrapped.
- """
- return self._sslobj
-
- @property
- def need_ssldata(self):
- """Whether more record level data is needed to complete a handshake
- that is currently in progress."""
- return self._need_ssldata
-
- @property
- def wrapped(self):
- """
- Whether a security layer is currently in effect.
-
- Return False during handshake.
- """
- return self._state == _WRAPPED
-
- def do_handshake(self, callback=None):
- """Start the SSL handshake.
-
- Return a list of ssldata. A ssldata element is a list of buffers
-
- The optional *callback* argument can be used to install a callback that
- will be called when the handshake is complete. The callback will be
- called with None if successful, else an exception instance.
- """
- if self._state != _UNWRAPPED:
- raise RuntimeError('handshake in progress or completed')
- self._sslobj = self._context.wrap_bio(
- self._incoming, self._outgoing,
- server_side=self._server_side,
- server_hostname=self._server_hostname)
- self._state = _DO_HANDSHAKE
- self._handshake_cb = callback
- ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
- assert len(appdata) == 0
- return ssldata
-
- def shutdown(self, callback=None):
- """Start the SSL shutdown sequence.
-
- Return a list of ssldata. A ssldata element is a list of buffers
-
- The optional *callback* argument can be used to install a callback that
- will be called when the shutdown is complete. The callback will be
- called without arguments.
- """
- if self._state == _UNWRAPPED:
- raise RuntimeError('no security layer present')
- if self._state == _SHUTDOWN:
- raise RuntimeError('shutdown in progress')
- assert self._state in (_WRAPPED, _DO_HANDSHAKE)
- self._state = _SHUTDOWN
- self._shutdown_cb = callback
- ssldata, appdata = self.feed_ssldata(b'')
- assert appdata == [] or appdata == [b'']
- return ssldata
-
- def feed_eof(self):
- """Send a potentially "ragged" EOF.
-
- This method will raise an SSL_ERROR_EOF exception if the EOF is
- unexpected.
- """
- self._incoming.write_eof()
- ssldata, appdata = self.feed_ssldata(b'')
- assert appdata == [] or appdata == [b'']
-
- def feed_ssldata(self, data, only_handshake=False):
- """Feed SSL record level data into the pipe.
-
- The data must be a bytes instance. It is OK to send an empty bytes
- instance. This can be used to get ssldata for a handshake initiated by
- this endpoint.
-
- Return a (ssldata, appdata) tuple. The ssldata element is a list of
- buffers containing SSL data that needs to be sent to the remote SSL.
-
- The appdata element is a list of buffers containing plaintext data that
- needs to be forwarded to the application. The appdata list may contain
- an empty buffer indicating an SSL "close_notify" alert. This alert must
- be acknowledged by calling shutdown().
- """
- if self._state == _UNWRAPPED:
- # If unwrapped, pass plaintext data straight through.
- if data:
- appdata = [data]
- else:
- appdata = []
- return ([], appdata)
-
- self._need_ssldata = False
- if data:
- self._incoming.write(data)
-
- ssldata = []
- appdata = []
- try:
- if self._state == _DO_HANDSHAKE:
- # Call do_handshake() until it doesn't raise anymore.
- self._sslobj.do_handshake()
- self._state = _WRAPPED
- if self._handshake_cb:
- self._handshake_cb(None)
- if only_handshake:
- return (ssldata, appdata)
- # Handshake done: execute the wrapped block
-
- if self._state == _WRAPPED:
- # Main state: read data from SSL until close_notify
- while True:
- chunk = self._sslobj.read(self.max_size)
- appdata.append(chunk)
- if not chunk: # close_notify
- break
+def add_flowcontrol_defaults(high, low, kb):
+ if high is None:
+ if low is None:
+ hi = kb * 1024
+ else:
+ lo = low
+ hi = 4 * lo
+ else:
+ hi = high
+ if low is None:
+ lo = hi // 4
+ else:
+ lo = low
- elif self._state == _SHUTDOWN:
- # Call shutdown() until it doesn't raise anymore.
- self._sslobj.unwrap()
- self._sslobj = None
- self._state = _UNWRAPPED
- if self._shutdown_cb:
- self._shutdown_cb()
-
- elif self._state == _UNWRAPPED:
- # Drain possible plaintext data after close_notify.
- appdata.append(self._incoming.read())
- except (ssl.SSLError, ssl.CertificateError) as exc:
- exc_errno = getattr(exc, 'errno', None)
- if exc_errno not in (
- ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
- ssl.SSL_ERROR_SYSCALL):
- if self._state == _DO_HANDSHAKE and self._handshake_cb:
- self._handshake_cb(exc)
- raise
- self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
-
- # Check for record level data that needs to be sent back.
- # Happens for the initial handshake and renegotiations.
- if self._outgoing.pending:
- ssldata.append(self._outgoing.read())
- return (ssldata, appdata)
-
- def feed_appdata(self, data, offset=0):
- """Feed plaintext data into the pipe.
-
- Return an (ssldata, offset) tuple. The ssldata element is a list of
- buffers containing record level data that needs to be sent to the
- remote SSL instance. The offset is the number of plaintext bytes that
- were processed, which may be less than the length of data.
-
- NOTE: In case of short writes, this call MUST be retried with the SAME
- buffer passed into the *data* argument (i.e. the id() must be the
- same). This is an OpenSSL requirement. A further particularity is that
- a short write will always have offset == 0, because the _ssl module
- does not enable partial writes. And even though the offset is zero,
- there will still be encrypted data in ssldata.
- """
- assert 0 <= offset <= len(data)
- if self._state == _UNWRAPPED:
- # pass through data in unwrapped mode
- if offset < len(data):
- ssldata = [data[offset:]]
- else:
- ssldata = []
- return (ssldata, len(data))
+ if not hi >= lo >= 0:
+ raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
+ (hi, lo))
- ssldata = []
- view = memoryview(data)
- while True:
- self._need_ssldata = False
- try:
- if offset < len(view):
- offset += self._sslobj.write(view[offset:])
- except ssl.SSLError as exc:
- # It is not allowed to call write() after unwrap() until the
- # close_notify is acknowledged. We return the condition to the
- # caller as a short write.
- exc_errno = getattr(exc, 'errno', None)
- if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
- exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
- if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
- ssl.SSL_ERROR_WANT_WRITE,
- ssl.SSL_ERROR_SYSCALL):
- raise
- self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
-
- # See if there's any record level data back for us.
- if self._outgoing.pending:
- ssldata.append(self._outgoing.read())
- if offset == len(view) or self._need_ssldata:
- break
- return (ssldata, offset)
+ return hi, lo
class _SSLProtocolTransport(transports._FlowControlMixin,
transports.Transport):
+ _start_tls_compatible = True
_sendfile_compatible = constants._SendfileMode.FALLBACK
def __init__(self, loop, ssl_protocol):
self._loop = loop
- # SSLProtocol instance
self._ssl_protocol = ssl_protocol
self._closed = False
@@ -315,16 +109,15 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
self._closed = True
self._ssl_protocol._start_shutdown()
- def __del__(self, _warn=warnings.warn):
+ def __del__(self, _warnings=warnings):
if not self._closed:
- _warn(f"unclosed transport {self!r}", ResourceWarning, source=self)
- self.close()
+ self._closed = True
+ _warnings.warn(
+ "unclosed transport <asyncio._SSLProtocolTransport "
+ "object>", ResourceWarning)
def is_reading(self):
- tr = self._ssl_protocol._transport
- if tr is None:
- raise RuntimeError('SSL transport has not been initialized yet')
- return tr.is_reading()
+ return not self._ssl_protocol._app_reading_paused
def pause_reading(self):
"""Pause the receiving end.
@@ -332,7 +125,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
No data will be passed to the protocol's data_received()
method until resume_reading() is called.
"""
- self._ssl_protocol._transport.pause_reading()
+ self._ssl_protocol._pause_reading()
def resume_reading(self):
"""Resume the receiving end.
@@ -340,7 +133,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
Data received will once again be passed to the protocol's
data_received() method.
"""
- self._ssl_protocol._transport.resume_reading()
+ self._ssl_protocol._resume_reading()
def set_write_buffer_limits(self, high=None, low=None):
"""Set the high- and low-water limits for write flow control.
@@ -361,16 +154,51 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
reduces opportunities for doing I/O and computation
concurrently.
"""
- self._ssl_protocol._transport.set_write_buffer_limits(high, low)
+ self._ssl_protocol._set_write_buffer_limits(high, low)
+ self._ssl_protocol._control_app_writing()
+
+ def get_write_buffer_limits(self):
+ return (self._ssl_protocol._outgoing_low_water,
+ self._ssl_protocol._outgoing_high_water)
def get_write_buffer_size(self):
- """Return the current size of the write buffer."""
- return self._ssl_protocol._transport.get_write_buffer_size()
+ """Return the current size of the write buffers."""
+ return self._ssl_protocol._get_write_buffer_size()
+
+ def set_read_buffer_limits(self, high=None, low=None):
+ """Set the high- and low-water limits for read flow control.
+
+ These two values control when to call the upstream transport's
+ pause_reading() and resume_reading() methods. If specified,
+ the low-water limit must be less than or equal to the
+ high-water limit. Neither value can be negative.
+
+ The defaults are implementation-specific. If only the
+ high-water limit is given, the low-water limit defaults to an
+ implementation-specific value less than or equal to the
+ high-water limit. Setting high to zero forces low to zero as
+ well, and causes pause_reading() to be called whenever the
+ buffer becomes non-empty. Setting low to zero causes
+ resume_reading() to be called only once the buffer is empty.
+ Use of zero for either limit is generally sub-optimal as it
+ reduces opportunities for doing I/O and computation
+ concurrently.
+ """
+ self._ssl_protocol._set_read_buffer_limits(high, low)
+ self._ssl_protocol._control_ssl_reading()
+
+ def get_read_buffer_limits(self):
+ return (self._ssl_protocol._incoming_low_water,
+ self._ssl_protocol._incoming_high_water)
+
+ def get_read_buffer_size(self):
+ """Return the current size of the read buffer."""
+ return self._ssl_protocol._get_read_buffer_size()
@property
def _protocol_paused(self):
# Required for sendfile fallback pause_writing/resume_writing logic
- return self._ssl_protocol._transport._protocol_paused
+ return self._ssl_protocol._app_writing_paused
def write(self, data):
"""Write some data bytes to the transport.
@@ -383,7 +211,22 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
f"got {type(data).__name__}")
if not data:
return
- self._ssl_protocol._write_appdata(data)
+ self._ssl_protocol._write_appdata((data,))
+
+ def writelines(self, list_of_data):
+ """Write a list (or any iterable) of data bytes to the transport.
+
+ The default implementation concatenates the arguments and
+ calls write() on the result.
+ """
+ self._ssl_protocol._write_appdata(list_of_data)
+
+ def write_eof(self):
+ """Close the write end after flushing buffered data.
+
+ This raises :exc:`NotImplementedError` right now.
+ """
+ raise NotImplementedError
def can_write_eof(self):
"""Return True if this transport supports write_eof(), False if not."""
@@ -396,23 +239,36 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
The protocol's connection_lost() method will (eventually) be
called with None as its argument.
"""
+ self._closed = True
self._ssl_protocol._abort()
+
+ def _force_close(self, exc):
self._closed = True
+ self._ssl_protocol._abort(exc)
+ def _test__append_write_backlog(self, data):
+ # for test only
+ self._ssl_protocol._write_backlog.append(data)
+ self._ssl_protocol._write_buffer_size += len(data)
-class SSLProtocol(protocols.Protocol):
- """SSL protocol.
- Implementation of SSL on top of a socket using incoming and outgoing
- buffers which are ssl.MemoryBIO objects.
- """
+class SSLProtocol(protocols.BufferedProtocol):
+ max_size = 256 * 1024 # Buffer size passed to read()
+
+ _handshake_start_time = None
+ _handshake_timeout_handle = None
+ _shutdown_timeout_handle = None
def __init__(self, loop, app_protocol, sslcontext, waiter,
server_side=False, server_hostname=None,
call_connection_made=True,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
if ssl is None:
- raise RuntimeError('stdlib ssl module not available')
+ raise RuntimeError("stdlib ssl module not available")
+
+ self._ssl_buffer = bytearray(self.max_size)
+ self._ssl_buffer_view = memoryview(self._ssl_buffer)
if ssl_handshake_timeout is None:
ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
@@ -420,6 +276,12 @@ class SSLProtocol(protocols.Protocol):
raise ValueError(
f"ssl_handshake_timeout should be a positive number, "
f"got {ssl_handshake_timeout}")
+ if ssl_shutdown_timeout is None:
+ ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT
+ elif ssl_shutdown_timeout <= 0:
+ raise ValueError(
+ f"ssl_shutdown_timeout should be a positive number, "
+ f"got {ssl_shutdown_timeout}")
if not sslcontext:
sslcontext = _create_transport_context(
@@ -442,21 +304,54 @@ class SSLProtocol(protocols.Protocol):
self._waiter = waiter
self._loop = loop
self._set_app_protocol(app_protocol)
- self._app_transport = _SSLProtocolTransport(self._loop, self)
- # _SSLPipe instance (None until the connection is made)
- self._sslpipe = None
- self._session_established = False
- self._in_handshake = False
- self._in_shutdown = False
+ self._app_transport = None
+ self._app_transport_created = False
# transport, ex: SelectorSocketTransport
self._transport = None
- self._call_connection_made = call_connection_made
self._ssl_handshake_timeout = ssl_handshake_timeout
+ self._ssl_shutdown_timeout = ssl_shutdown_timeout
+ # SSL and state machine
+ self._incoming = ssl.MemoryBIO()
+ self._outgoing = ssl.MemoryBIO()
+ self._state = SSLProtocolState.UNWRAPPED
+ self._conn_lost = 0 # Set when connection_lost called
+ if call_connection_made:
+ self._app_state = AppProtocolState.STATE_INIT
+ else:
+ self._app_state = AppProtocolState.STATE_CON_MADE
+ self._sslobj = self._sslcontext.wrap_bio(
+ self._incoming, self._outgoing,
+ server_side=self._server_side,
+ server_hostname=self._server_hostname)
+
+ # Flow Control
+
+ self._ssl_writing_paused = False
+
+ self._app_reading_paused = False
+
+ self._ssl_reading_paused = False
+ self._incoming_high_water = 0
+ self._incoming_low_water = 0
+ self._set_read_buffer_limits()
+ self._eof_received = False
+
+ self._app_writing_paused = False
+ self._outgoing_high_water = 0
+ self._outgoing_low_water = 0
+ self._set_write_buffer_limits()
+ self._get_app_transport()
def _set_app_protocol(self, app_protocol):
self._app_protocol = app_protocol
- self._app_protocol_is_buffer = \
- isinstance(app_protocol, protocols.BufferedProtocol)
+ # Make fast hasattr check first
+ if (hasattr(app_protocol, 'get_buffer') and
+ isinstance(app_protocol, protocols.BufferedProtocol)):
+ self._app_protocol_get_buffer = app_protocol.get_buffer
+ self._app_protocol_buffer_updated = app_protocol.buffer_updated
+ self._app_protocol_is_buffer = True
+ else:
+ self._app_protocol_is_buffer = False
def _wakeup_waiter(self, exc=None):
if self._waiter is None:
@@ -468,15 +363,20 @@ class SSLProtocol(protocols.Protocol):
self._waiter.set_result(None)
self._waiter = None
+ def _get_app_transport(self):
+ if self._app_transport is None:
+ if self._app_transport_created:
+ raise RuntimeError('Creating _SSLProtocolTransport twice')
+ self._app_transport = _SSLProtocolTransport(self._loop, self)
+ self._app_transport_created = True
+ return self._app_transport
+
def connection_made(self, transport):
"""Called when the low-level connection is made.
Start the SSL handshake.
"""
self._transport = transport
- self._sslpipe = _SSLPipe(self._sslcontext,
- self._server_side,
- self._server_hostname)
self._start_handshake()
def connection_lost(self, exc):
@@ -486,72 +386,58 @@ class SSLProtocol(protocols.Protocol):
meaning a regular EOF is received or the connection was
aborted or closed).
"""
- if self._session_established:
- self._session_established = False
- self._loop.call_soon(self._app_protocol.connection_lost, exc)
- else:
- # Most likely an exception occurred while in SSL handshake.
- # Just mark the app transport as closed so that its __del__
- # doesn't complain.
- if self._app_transport is not None:
- self._app_transport._closed = True
+ self._write_backlog.clear()
+ self._outgoing.read()
+ self._conn_lost += 1
+
+ # Just mark the app transport as closed so that its __dealloc__
+ # doesn't complain.
+ if self._app_transport is not None:
+ self._app_transport._closed = True
+
+ if self._state != SSLProtocolState.DO_HANDSHAKE:
+ if (
+ self._app_state == AppProtocolState.STATE_CON_MADE or
+ self._app_state == AppProtocolState.STATE_EOF
+ ):
+ self._app_state = AppProtocolState.STATE_CON_LOST
+ self._loop.call_soon(self._app_protocol.connection_lost, exc)
+ self._set_state(SSLProtocolState.UNWRAPPED)
self._transport = None
self._app_transport = None
- if getattr(self, '_handshake_timeout_handle', None):
- self._handshake_timeout_handle.cancel()
- self._wakeup_waiter(exc)
self._app_protocol = None
- self._sslpipe = None
+ self._wakeup_waiter(exc)
- def pause_writing(self):
- """Called when the low-level transport's buffer goes over
- the high-water mark.
- """
- self._app_protocol.pause_writing()
+ if self._shutdown_timeout_handle:
+ self._shutdown_timeout_handle.cancel()
+ self._shutdown_timeout_handle = None
+ if self._handshake_timeout_handle:
+ self._handshake_timeout_handle.cancel()
+ self._handshake_timeout_handle = None
- def resume_writing(self):
- """Called when the low-level transport's buffer drains below
- the low-water mark.
- """
- self._app_protocol.resume_writing()
+ def get_buffer(self, n):
+ want = n
+ if want <= 0 or want > self.max_size:
+ want = self.max_size
+ if len(self._ssl_buffer) < want:
+ self._ssl_buffer = bytearray(want)
+ self._ssl_buffer_view = memoryview(self._ssl_buffer)
+ return self._ssl_buffer_view
- def data_received(self, data):
- """Called when some SSL data is received.
+ def buffer_updated(self, nbytes):
+ self._incoming.write(self._ssl_buffer_view[:nbytes])
- The argument is a bytes object.
- """
- if self._sslpipe is None:
- # transport closing, sslpipe is destroyed
- return
+ if self._state == SSLProtocolState.DO_HANDSHAKE:
+ self._do_handshake()
- try:
- ssldata, appdata = self._sslpipe.feed_ssldata(data)
- except (SystemExit, KeyboardInterrupt):
- raise
- except BaseException as e:
- self._fatal_error(e, 'SSL error in data received')
- return
+ elif self._state == SSLProtocolState.WRAPPED:
+ self._do_read()
- for chunk in ssldata:
- self._transport.write(chunk)
+ elif self._state == SSLProtocolState.FLUSHING:
+ self._do_flush()
- for chunk in appdata:
- if chunk:
- try:
- if self._app_protocol_is_buffer:
- protocols._feed_data_to_buffered_proto(
- self._app_protocol, chunk)
- else:
- self._app_protocol.data_received(chunk)
- except (SystemExit, KeyboardInterrupt):
- raise
- except BaseException as ex:
- self._fatal_error(
- ex, 'application protocol failed to receive SSL data')
- return
- else:
- self._start_shutdown()
- break
+ elif self._state == SSLProtocolState.SHUTDOWN:
+ self._do_shutdown()
def eof_received(self):
"""Called when the other end of the low-level stream
@@ -561,19 +447,32 @@ class SSLProtocol(protocols.Protocol):
will close itself. If it returns a true value, closing the
transport is up to the protocol.
"""
+ self._eof_received = True
try:
if self._loop.get_debug():
logger.debug("%r received EOF", self)
- self._wakeup_waiter(ConnectionResetError)
+ if self._state == SSLProtocolState.DO_HANDSHAKE:
+ self._on_handshake_complete(ConnectionResetError)
- if not self._in_handshake:
- keep_open = self._app_protocol.eof_received()
- if keep_open:
- logger.warning('returning true from eof_received() '
- 'has no effect when using ssl')
- finally:
+ elif self._state == SSLProtocolState.WRAPPED:
+ self._set_state(SSLProtocolState.FLUSHING)
+ if self._app_reading_paused:
+ return True
+ else:
+ self._do_flush()
+
+ elif self._state == SSLProtocolState.FLUSHING:
+ self._do_write()
+ self._set_state(SSLProtocolState.SHUTDOWN)
+ self._do_shutdown()
+
+ elif self._state == SSLProtocolState.SHUTDOWN:
+ self._do_shutdown()
+
+ except Exception:
self._transport.close()
+ raise
def _get_extra_info(self, name, default=None):
if name in self._extra:
@@ -583,19 +482,45 @@ class SSLProtocol(protocols.Protocol):
else:
return default
- def _start_shutdown(self):
- if self._in_shutdown:
- return
- if self._in_handshake:
- self._abort()
+ def _set_state(self, new_state):
+ allowed = False
+
+ if new_state == SSLProtocolState.UNWRAPPED:
+ allowed = True
+
+ elif (
+ self._state == SSLProtocolState.UNWRAPPED and
+ new_state == SSLProtocolState.DO_HANDSHAKE
+ ):
+ allowed = True
+
+ elif (
+ self._state == SSLProtocolState.DO_HANDSHAKE and
+ new_state == SSLProtocolState.WRAPPED
+ ):
+ allowed = True
+
+ elif (
+ self._state == SSLProtocolState.WRAPPED and
+ new_state == SSLProtocolState.FLUSHING
+ ):
+ allowed = True
+
+ elif (
+ self._state == SSLProtocolState.FLUSHING and
+ new_state == SSLProtocolState.SHUTDOWN
+ ):
+ allowed = True
+
+ if allowed:
+ self._state = new_state
+
else:
- self._in_shutdown = True
- self._write_appdata(b'')
+ raise RuntimeError(
+ 'cannot switch state from {} to {}'.format(
+ self._state, new_state))
- def _write_appdata(self, data):
- self._write_backlog.append((data, 0))
- self._write_buffer_size += len(data)
- self._process_write_backlog()
+ # Handshake flow
def _start_handshake(self):
if self._loop.get_debug():
@@ -603,17 +528,18 @@ class SSLProtocol(protocols.Protocol):
self._handshake_start_time = self._loop.time()
else:
self._handshake_start_time = None
- self._in_handshake = True
- # (b'', 1) is a special value in _process_write_backlog() to do
- # the SSL handshake
- self._write_backlog.append((b'', 1))
+
+ self._set_state(SSLProtocolState.DO_HANDSHAKE)
+
+ # start handshake timeout count down
self._handshake_timeout_handle = \
self._loop.call_later(self._ssl_handshake_timeout,
- self._check_handshake_timeout)
- self._process_write_backlog()
+ lambda: self._check_handshake_timeout())
+
+ self._do_handshake()
def _check_handshake_timeout(self):
- if self._in_handshake is True:
+ if self._state == SSLProtocolState.DO_HANDSHAKE:
msg = (
f"SSL handshake is taking longer than "
f"{self._ssl_handshake_timeout} seconds: "
@@ -621,24 +547,37 @@ class SSLProtocol(protocols.Protocol):
)
self._fatal_error(ConnectionAbortedError(msg))
+ def _do_handshake(self):
+ try:
+ self._sslobj.do_handshake()
+ except SSLAgainErrors:
+ self._process_outgoing()
+ except ssl.SSLError as exc:
+ self._on_handshake_complete(exc)
+ else:
+ self._on_handshake_complete(None)
+
def _on_handshake_complete(self, handshake_exc):
- self._in_handshake = False
- self._handshake_timeout_handle.cancel()
+ if self._handshake_timeout_handle is not None:
+ self._handshake_timeout_handle.cancel()
+ self._handshake_timeout_handle = None
- sslobj = self._sslpipe.ssl_object
+ sslobj = self._sslobj
try:
- if handshake_exc is not None:
+ if handshake_exc is None:
+ self._set_state(SSLProtocolState.WRAPPED)
+ else:
raise handshake_exc
peercert = sslobj.getpeercert()
- except (SystemExit, KeyboardInterrupt):
- raise
- except BaseException as exc:
+ except Exception as exc:
+ self._set_state(SSLProtocolState.UNWRAPPED)
if isinstance(exc, ssl.CertificateError):
msg = 'SSL handshake failed on verifying the certificate'
else:
msg = 'SSL handshake failed'
self._fatal_error(exc, msg)
+ self._wakeup_waiter(exc)
return
if self._loop.get_debug():
@@ -649,85 +588,330 @@ class SSLProtocol(protocols.Protocol):
self._extra.update(peercert=peercert,
cipher=sslobj.cipher(),
compression=sslobj.compression(),
- ssl_object=sslobj,
- )
- if self._call_connection_made:
- self._app_protocol.connection_made(self._app_transport)
+ ssl_object=sslobj)
+ if self._app_state == AppProtocolState.STATE_INIT:
+ self._app_state = AppProtocolState.STATE_CON_MADE
+ self._app_protocol.connection_made(self._get_app_transport())
self._wakeup_waiter()
- self._session_established = True
- # In case transport.write() was already called. Don't call
- # immediately _process_write_backlog(), but schedule it:
- # _on_handshake_complete() can be called indirectly from
- # _process_write_backlog(), and _process_write_backlog() is not
- # reentrant.
- self._loop.call_soon(self._process_write_backlog)
-
- def _process_write_backlog(self):
- # Try to make progress on the write backlog.
- if self._transport is None or self._sslpipe is None:
+ self._do_read()
+
+ # Shutdown flow
+
+ def _start_shutdown(self):
+ if (
+ self._state in (
+ SSLProtocolState.FLUSHING,
+ SSLProtocolState.SHUTDOWN,
+ SSLProtocolState.UNWRAPPED
+ )
+ ):
+ return
+ if self._app_transport is not None:
+ self._app_transport._closed = True
+ if self._state == SSLProtocolState.DO_HANDSHAKE:
+ self._abort()
+ else:
+ self._set_state(SSLProtocolState.FLUSHING)
+ self._shutdown_timeout_handle = self._loop.call_later(
+ self._ssl_shutdown_timeout,
+ lambda: self._check_shutdown_timeout()
+ )
+ self._do_flush()
+
+ def _check_shutdown_timeout(self):
+ if (
+ self._state in (
+ SSLProtocolState.FLUSHING,
+ SSLProtocolState.SHUTDOWN
+ )
+ ):
+ self._transport._force_close(
+ exceptions.TimeoutError('SSL shutdown timed out'))
+
+ def _do_flush(self):
+ self._do_read()
+ self._set_state(SSLProtocolState.SHUTDOWN)
+ self._do_shutdown()
+
+ def _do_shutdown(self):
+ try:
+ if not self._eof_received:
+ self._sslobj.unwrap()
+ except SSLAgainErrors:
+ self._process_outgoing()
+ except ssl.SSLError as exc:
+ self._on_shutdown_complete(exc)
+ else:
+ self._process_outgoing()
+ self._call_eof_received()
+ self._on_shutdown_complete(None)
+
+ def _on_shutdown_complete(self, shutdown_exc):
+ if self._shutdown_timeout_handle is not None:
+ self._shutdown_timeout_handle.cancel()
+ self._shutdown_timeout_handle = None
+
+ if shutdown_exc:
+ self._fatal_error(shutdown_exc)
+ else:
+ self._loop.call_soon(self._transport.close)
+
+ def _abort(self):
+ self._set_state(SSLProtocolState.UNWRAPPED)
+ if self._transport is not None:
+ self._transport.abort()
+
+ # Outgoing flow
+
+ def _write_appdata(self, list_of_data):
+ if (
+ self._state in (
+ SSLProtocolState.FLUSHING,
+ SSLProtocolState.SHUTDOWN,
+ SSLProtocolState.UNWRAPPED
+ )
+ ):
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ logger.warning('SSL connection is closed')
+ self._conn_lost += 1
return
+ for data in list_of_data:
+ self._write_backlog.append(data)
+ self._write_buffer_size += len(data)
+
try:
- for i in range(len(self._write_backlog)):
- data, offset = self._write_backlog[0]
- if data:
- ssldata, offset = self._sslpipe.feed_appdata(data, offset)
- elif offset:
- ssldata = self._sslpipe.do_handshake(
- self._on_handshake_complete)
- offset = 1
+ if self._state == SSLProtocolState.WRAPPED:
+ self._do_write()
+
+ except Exception as ex:
+ self._fatal_error(ex, 'Fatal error on SSL protocol')
+
+ def _do_write(self):
+ try:
+ while self._write_backlog:
+ data = self._write_backlog[0]
+ count = self._sslobj.write(data)
+ data_len = len(data)
+ if count < data_len:
+ self._write_backlog[0] = data[count:]
+ self._write_buffer_size -= count
else:
- ssldata = self._sslpipe.shutdown(self._finalize)
- offset = 1
-
- for chunk in ssldata:
- self._transport.write(chunk)
-
- if offset < len(data):
- self._write_backlog[0] = (data, offset)
- # A short write means that a write is blocked on a read
- # We need to enable reading if it is paused!
- assert self._sslpipe.need_ssldata
- if self._transport._paused:
- self._transport.resume_reading()
- break
+ del self._write_backlog[0]
+ self._write_buffer_size -= data_len
+ except SSLAgainErrors:
+ pass
+ self._process_outgoing()
+
+ def _process_outgoing(self):
+ if not self._ssl_writing_paused:
+ data = self._outgoing.read()
+ if len(data):
+ self._transport.write(data)
+ self._control_app_writing()
+
+ # Incoming flow
+
+ def _do_read(self):
+ if (
+ self._state not in (
+ SSLProtocolState.WRAPPED,
+ SSLProtocolState.FLUSHING,
+ )
+ ):
+ return
+ try:
+ if not self._app_reading_paused:
+ if self._app_protocol_is_buffer:
+ self._do_read__buffered()
+ else:
+ self._do_read__copied()
+ if self._write_backlog:
+ self._do_write()
+ else:
+ self._process_outgoing()
+ self._control_ssl_reading()
+ except Exception as ex:
+ self._fatal_error(ex, 'Fatal error on SSL protocol')
+
+ def _do_read__buffered(self):
+ offset = 0
+ count = 1
+
+ buf = self._app_protocol_get_buffer(self._get_read_buffer_size())
+ wants = len(buf)
- # An entire chunk from the backlog was processed. We can
- # delete it and reduce the outstanding buffer size.
- del self._write_backlog[0]
- self._write_buffer_size -= len(data)
- except (SystemExit, KeyboardInterrupt):
+ try:
+ count = self._sslobj.read(wants, buf)
+
+ if count > 0:
+ offset = count
+ while offset < wants:
+ count = self._sslobj.read(wants - offset, buf[offset:])
+ if count > 0:
+ offset += count
+ else:
+ break
+ else:
+ self._loop.call_soon(lambda: self._do_read())
+ except SSLAgainErrors:
+ pass
+ if offset > 0:
+ self._app_protocol_buffer_updated(offset)
+ if not count:
+ # close_notify
+ self._call_eof_received()
+ self._start_shutdown()
+
+ def _do_read__copied(self):
+ chunk = b'1'
+ zero = True
+ one = False
+
+ try:
+ while True:
+ chunk = self._sslobj.read(self.max_size)
+ if not chunk:
+ break
+ if zero:
+ zero = False
+ one = True
+ first = chunk
+ elif one:
+ one = False
+ data = [first, chunk]
+ else:
+ data.append(chunk)
+ except SSLAgainErrors:
+ pass
+ if one:
+ self._app_protocol.data_received(first)
+ elif not zero:
+ self._app_protocol.data_received(b''.join(data))
+ if not chunk:
+ # close_notify
+ self._call_eof_received()
+ self._start_shutdown()
+
+ def _call_eof_received(self):
+ try:
+ if self._app_state == AppProtocolState.STATE_CON_MADE:
+ self._app_state = AppProtocolState.STATE_EOF
+ keep_open = self._app_protocol.eof_received()
+ if keep_open:
+ logger.warning('returning true from eof_received() '
+ 'has no effect when using ssl')
+ except (KeyboardInterrupt, SystemExit):
raise
- except BaseException as exc:
- if self._in_handshake:
- # Exceptions will be re-raised in _on_handshake_complete.
- self._on_handshake_complete(exc)
- else:
- self._fatal_error(exc, 'Fatal error on SSL transport')
+ except BaseException as ex:
+ self._fatal_error(ex, 'Error calling eof_received()')
+
+ # Flow control for writes from APP socket
+
+ def _control_app_writing(self):
+ size = self._get_write_buffer_size()
+ if size >= self._outgoing_high_water and not self._app_writing_paused:
+ self._app_writing_paused = True
+ try:
+ self._app_protocol.pause_writing()
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except BaseException as exc:
+ self._loop.call_exception_handler({
+ 'message': 'protocol.pause_writing() failed',
+ 'exception': exc,
+ 'transport': self._app_transport,
+ 'protocol': self,
+ })
+ elif size <= self._outgoing_low_water and self._app_writing_paused:
+ self._app_writing_paused = False
+ try:
+ self._app_protocol.resume_writing()
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except BaseException as exc:
+ self._loop.call_exception_handler({
+ 'message': 'protocol.resume_writing() failed',
+ 'exception': exc,
+ 'transport': self._app_transport,
+ 'protocol': self,
+ })
+
+ def _get_write_buffer_size(self):
+ return self._outgoing.pending + self._write_buffer_size
+
+ def _set_write_buffer_limits(self, high=None, low=None):
+ high, low = add_flowcontrol_defaults(
+ high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE)
+ self._outgoing_high_water = high
+ self._outgoing_low_water = low
+
+ # Flow control for reads to APP socket
+
+ def _pause_reading(self):
+ self._app_reading_paused = True
+
+ def _resume_reading(self):
+ if self._app_reading_paused:
+ self._app_reading_paused = False
+
+ def resume():
+ if self._state == SSLProtocolState.WRAPPED:
+ self._do_read()
+ elif self._state == SSLProtocolState.FLUSHING:
+ self._do_flush()
+ elif self._state == SSLProtocolState.SHUTDOWN:
+ self._do_shutdown()
+ self._loop.call_soon(resume)
+
+ # Flow control for reads from SSL socket
+
+ def _control_ssl_reading(self):
+ size = self._get_read_buffer_size()
+ if size >= self._incoming_high_water and not self._ssl_reading_paused:
+ self._ssl_reading_paused = True
+ self._transport.pause_reading()
+ elif size <= self._incoming_low_water and self._ssl_reading_paused:
+ self._ssl_reading_paused = False
+ self._transport.resume_reading()
+
+ def _set_read_buffer_limits(self, high=None, low=None):
+ high, low = add_flowcontrol_defaults(
+ high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ)
+ self._incoming_high_water = high
+ self._incoming_low_water = low
+
+ def _get_read_buffer_size(self):
+ return self._incoming.pending
+
+ # Flow control for writes to SSL socket
+
+ def pause_writing(self):
+ """Called when the low-level transport's buffer goes over
+ the high-water mark.
+ """
+ assert not self._ssl_writing_paused
+ self._ssl_writing_paused = True
+
+ def resume_writing(self):
+ """Called when the low-level transport's buffer drains below
+ the low-water mark.
+ """
+ assert self._ssl_writing_paused
+ self._ssl_writing_paused = False
+ self._process_outgoing()
def _fatal_error(self, exc, message='Fatal error on transport'):
+ if self._transport:
+ self._transport._force_close(exc)
+
if isinstance(exc, OSError):
if self._loop.get_debug():
logger.debug("%r: %s", self, message, exc_info=True)
- else:
+ elif not isinstance(exc, exceptions.CancelledError):
self._loop.call_exception_handler({
'message': message,
'exception': exc,
'transport': self._transport,
'protocol': self,
})
- if self._transport:
- self._transport._force_close(exc)
-
- def _finalize(self):
- self._sslpipe = None
-
- if self._transport is not None:
- self._transport.close()
-
- def _abort(self):
- try:
- if self._transport is not None:
- self._transport.abort()
- finally:
- self._finalize()