diff options
author | Pablo Galindo <Pablogsal@gmail.com> | 2021-05-03 15:21:59 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-03 15:21:59 (GMT) |
commit | 7719953b30430b351ba0f153c2b51b16cc68ee36 (patch) | |
tree | 8014086b85a13ed79d45e29ab74a9a9f5c9c68eb /Lib/asyncio/sslproto.py | |
parent | 39494285e15dc2d291ec13de5045b930eaf0a3db (diff) | |
download | cpython-7719953b30430b351ba0f153c2b51b16cc68ee36.zip cpython-7719953b30430b351ba0f153c2b51b16cc68ee36.tar.gz cpython-7719953b30430b351ba0f153c2b51b16cc68ee36.tar.bz2 |
bpo-44011: Revert "New asyncio ssl implementation (GH-17975)" (GH-25848)
This reverts commit 5fb06edbbb769561e245d0fe13002bab50e2ae60 and all
subsequent dependent commits.
Diffstat (limited to 'Lib/asyncio/sslproto.py')
-rw-r--r-- | Lib/asyncio/sslproto.py | 1059 |
1 files changed, 437 insertions, 622 deletions
diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 79734ab..cad25b2 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -1,5 +1,4 @@ import collections -import enum import warnings try: import ssl @@ -7,38 +6,10 @@ except ImportError: # pragma: no cover ssl = None from . import constants -from . import exceptions from . import protocols from . import transports from .log import logger -if ssl is not None: - SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError) - - -class SSLProtocolState(enum.Enum): - UNWRAPPED = "UNWRAPPED" - DO_HANDSHAKE = "DO_HANDSHAKE" - WRAPPED = "WRAPPED" - FLUSHING = "FLUSHING" - SHUTDOWN = "SHUTDOWN" - - -class AppProtocolState(enum.Enum): - # This tracks the state of app protocol (https://git.io/fj59P): - # - # INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST - # - # * cm: connection_made() - # * dr: data_received() - # * er: eof_received() - # * cl: connection_lost() - - STATE_INIT = "STATE_INIT" - STATE_CON_MADE = "STATE_CON_MADE" - STATE_EOF = "STATE_EOF" - STATE_CON_LOST = "STATE_CON_LOST" - def _create_transport_context(server_side, server_hostname): if server_side: @@ -54,35 +25,269 @@ def _create_transport_context(server_side, server_hostname): return sslcontext -def add_flowcontrol_defaults(high, low, kb): - if high is None: - if low is None: - hi = kb * 1024 - else: - lo = low - hi = 4 * lo - else: - hi = high - if low is None: - lo = hi // 4 - else: - lo = low +# States of an _SSLPipe. +_UNWRAPPED = "UNWRAPPED" +_DO_HANDSHAKE = "DO_HANDSHAKE" +_WRAPPED = "WRAPPED" +_SHUTDOWN = "SHUTDOWN" + + +class _SSLPipe(object): + """An SSL "Pipe". + + An SSL pipe allows you to communicate with an SSL/TLS protocol instance + through memory buffers. It can be used to implement a security layer for an + existing connection where you don't have access to the connection's file + descriptor, or for some reason you don't want to use it. + + An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode, + data is passed through untransformed. In wrapped mode, application level + data is encrypted to SSL record level data and vice versa. The SSL record + level is the lowest level in the SSL protocol suite and is what travels + as-is over the wire. + + An SslPipe initially is in "unwrapped" mode. To start SSL, call + do_handshake(). To shutdown SSL again, call unwrap(). + """ + + max_size = 256 * 1024 # Buffer size passed to read() + + def __init__(self, context, server_side, server_hostname=None): + """ + The *context* argument specifies the ssl.SSLContext to use. + + The *server_side* argument indicates whether this is a server side or + client side transport. + + The optional *server_hostname* argument can be used to specify the + hostname you are connecting to. You may only specify this parameter if + the _ssl module supports Server Name Indication (SNI). + """ + self._context = context + self._server_side = server_side + self._server_hostname = server_hostname + self._state = _UNWRAPPED + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + self._sslobj = None + self._need_ssldata = False + self._handshake_cb = None + self._shutdown_cb = None + + @property + def context(self): + """The SSL context passed to the constructor.""" + return self._context + + @property + def ssl_object(self): + """The internal ssl.SSLObject instance. + + Return None if the pipe is not wrapped. + """ + return self._sslobj + + @property + def need_ssldata(self): + """Whether more record level data is needed to complete a handshake + that is currently in progress.""" + return self._need_ssldata + + @property + def wrapped(self): + """ + Whether a security layer is currently in effect. + + Return False during handshake. + """ + return self._state == _WRAPPED + + def do_handshake(self, callback=None): + """Start the SSL handshake. + + Return a list of ssldata. A ssldata element is a list of buffers + + The optional *callback* argument can be used to install a callback that + will be called when the handshake is complete. The callback will be + called with None if successful, else an exception instance. + """ + if self._state != _UNWRAPPED: + raise RuntimeError('handshake in progress or completed') + self._sslobj = self._context.wrap_bio( + self._incoming, self._outgoing, + server_side=self._server_side, + server_hostname=self._server_hostname) + self._state = _DO_HANDSHAKE + self._handshake_cb = callback + ssldata, appdata = self.feed_ssldata(b'', only_handshake=True) + assert len(appdata) == 0 + return ssldata + + def shutdown(self, callback=None): + """Start the SSL shutdown sequence. + + Return a list of ssldata. A ssldata element is a list of buffers - if not hi >= lo >= 0: - raise ValueError('high (%r) must be >= low (%r) must be >= 0' % - (hi, lo)) + The optional *callback* argument can be used to install a callback that + will be called when the shutdown is complete. The callback will be + called without arguments. + """ + if self._state == _UNWRAPPED: + raise RuntimeError('no security layer present') + if self._state == _SHUTDOWN: + raise RuntimeError('shutdown in progress') + assert self._state in (_WRAPPED, _DO_HANDSHAKE) + self._state = _SHUTDOWN + self._shutdown_cb = callback + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + return ssldata + + def feed_eof(self): + """Send a potentially "ragged" EOF. + + This method will raise an SSL_ERROR_EOF exception if the EOF is + unexpected. + """ + self._incoming.write_eof() + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + + def feed_ssldata(self, data, only_handshake=False): + """Feed SSL record level data into the pipe. + + The data must be a bytes instance. It is OK to send an empty bytes + instance. This can be used to get ssldata for a handshake initiated by + this endpoint. + + Return a (ssldata, appdata) tuple. The ssldata element is a list of + buffers containing SSL data that needs to be sent to the remote SSL. + + The appdata element is a list of buffers containing plaintext data that + needs to be forwarded to the application. The appdata list may contain + an empty buffer indicating an SSL "close_notify" alert. This alert must + be acknowledged by calling shutdown(). + """ + if self._state == _UNWRAPPED: + # If unwrapped, pass plaintext data straight through. + if data: + appdata = [data] + else: + appdata = [] + return ([], appdata) + + self._need_ssldata = False + if data: + self._incoming.write(data) + + ssldata = [] + appdata = [] + try: + if self._state == _DO_HANDSHAKE: + # Call do_handshake() until it doesn't raise anymore. + self._sslobj.do_handshake() + self._state = _WRAPPED + if self._handshake_cb: + self._handshake_cb(None) + if only_handshake: + return (ssldata, appdata) + # Handshake done: execute the wrapped block + + if self._state == _WRAPPED: + # Main state: read data from SSL until close_notify + while True: + chunk = self._sslobj.read(self.max_size) + appdata.append(chunk) + if not chunk: # close_notify + break + + elif self._state == _SHUTDOWN: + # Call shutdown() until it doesn't raise anymore. + self._sslobj.unwrap() + self._sslobj = None + self._state = _UNWRAPPED + if self._shutdown_cb: + self._shutdown_cb() + + elif self._state == _UNWRAPPED: + # Drain possible plaintext data after close_notify. + appdata.append(self._incoming.read()) + except (ssl.SSLError, ssl.CertificateError) as exc: + exc_errno = getattr(exc, 'errno', None) + if exc_errno not in ( + ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + if self._state == _DO_HANDSHAKE and self._handshake_cb: + self._handshake_cb(exc) + raise + self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ) + + # Check for record level data that needs to be sent back. + # Happens for the initial handshake and renegotiations. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + return (ssldata, appdata) + + def feed_appdata(self, data, offset=0): + """Feed plaintext data into the pipe. + + Return an (ssldata, offset) tuple. The ssldata element is a list of + buffers containing record level data that needs to be sent to the + remote SSL instance. The offset is the number of plaintext bytes that + were processed, which may be less than the length of data. + + NOTE: In case of short writes, this call MUST be retried with the SAME + buffer passed into the *data* argument (i.e. the id() must be the + same). This is an OpenSSL requirement. A further particularity is that + a short write will always have offset == 0, because the _ssl module + does not enable partial writes. And even though the offset is zero, + there will still be encrypted data in ssldata. + """ + assert 0 <= offset <= len(data) + if self._state == _UNWRAPPED: + # pass through data in unwrapped mode + if offset < len(data): + ssldata = [data[offset:]] + else: + ssldata = [] + return (ssldata, len(data)) - return hi, lo + ssldata = [] + view = memoryview(data) + while True: + self._need_ssldata = False + try: + if offset < len(view): + offset += self._sslobj.write(view[offset:]) + except ssl.SSLError as exc: + # It is not allowed to call write() after unwrap() until the + # close_notify is acknowledged. We return the condition to the + # caller as a short write. + exc_errno = getattr(exc, 'errno', None) + if exc.reason == 'PROTOCOL_IS_SHUTDOWN': + exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ + if exc_errno not in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + raise + self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ) + + # See if there's any record level data back for us. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + if offset == len(view) or self._need_ssldata: + break + return (ssldata, offset) class _SSLProtocolTransport(transports._FlowControlMixin, transports.Transport): - _start_tls_compatible = True _sendfile_compatible = constants._SendfileMode.FALLBACK def __init__(self, loop, ssl_protocol): self._loop = loop + # SSLProtocol instance self._ssl_protocol = ssl_protocol self._closed = False @@ -110,15 +315,16 @@ class _SSLProtocolTransport(transports._FlowControlMixin, self._closed = True self._ssl_protocol._start_shutdown() - def __del__(self, _warnings=warnings): + def __del__(self, _warn=warnings.warn): if not self._closed: - self._closed = True - _warnings.warn( - "unclosed transport <asyncio._SSLProtocolTransport " - "object>", ResourceWarning) + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self.close() def is_reading(self): - return not self._ssl_protocol._app_reading_paused + tr = self._ssl_protocol._transport + if tr is None: + raise RuntimeError('SSL transport has not been initialized yet') + return tr.is_reading() def pause_reading(self): """Pause the receiving end. @@ -126,7 +332,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin, No data will be passed to the protocol's data_received() method until resume_reading() is called. """ - self._ssl_protocol._pause_reading() + self._ssl_protocol._transport.pause_reading() def resume_reading(self): """Resume the receiving end. @@ -134,7 +340,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin, Data received will once again be passed to the protocol's data_received() method. """ - self._ssl_protocol._resume_reading() + self._ssl_protocol._transport.resume_reading() def set_write_buffer_limits(self, high=None, low=None): """Set the high- and low-water limits for write flow control. @@ -155,51 +361,16 @@ class _SSLProtocolTransport(transports._FlowControlMixin, reduces opportunities for doing I/O and computation concurrently. """ - self._ssl_protocol._set_write_buffer_limits(high, low) - self._ssl_protocol._control_app_writing() - - def get_write_buffer_limits(self): - return (self._ssl_protocol._outgoing_low_water, - self._ssl_protocol._outgoing_high_water) + self._ssl_protocol._transport.set_write_buffer_limits(high, low) def get_write_buffer_size(self): - """Return the current size of the write buffers.""" - return self._ssl_protocol._get_write_buffer_size() - - def set_read_buffer_limits(self, high=None, low=None): - """Set the high- and low-water limits for read flow control. - - These two values control when to call the upstream transport's - pause_reading() and resume_reading() methods. If specified, - the low-water limit must be less than or equal to the - high-water limit. Neither value can be negative. - - The defaults are implementation-specific. If only the - high-water limit is given, the low-water limit defaults to an - implementation-specific value less than or equal to the - high-water limit. Setting high to zero forces low to zero as - well, and causes pause_reading() to be called whenever the - buffer becomes non-empty. Setting low to zero causes - resume_reading() to be called only once the buffer is empty. - Use of zero for either limit is generally sub-optimal as it - reduces opportunities for doing I/O and computation - concurrently. - """ - self._ssl_protocol._set_read_buffer_limits(high, low) - self._ssl_protocol._control_ssl_reading() - - def get_read_buffer_limits(self): - return (self._ssl_protocol._incoming_low_water, - self._ssl_protocol._incoming_high_water) - - def get_read_buffer_size(self): - """Return the current size of the read buffer.""" - return self._ssl_protocol._get_read_buffer_size() + """Return the current size of the write buffer.""" + return self._ssl_protocol._transport.get_write_buffer_size() @property def _protocol_paused(self): # Required for sendfile fallback pause_writing/resume_writing logic - return self._ssl_protocol._app_writing_paused + return self._ssl_protocol._transport._protocol_paused def write(self, data): """Write some data bytes to the transport. @@ -212,22 +383,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin, f"got {type(data).__name__}") if not data: return - self._ssl_protocol._write_appdata((data,)) - - def writelines(self, list_of_data): - """Write a list (or any iterable) of data bytes to the transport. - - The default implementation concatenates the arguments and - calls write() on the result. - """ - self._ssl_protocol._write_appdata(list_of_data) - - def write_eof(self): - """Close the write end after flushing buffered data. - - This raises :exc:`NotImplementedError` right now. - """ - raise NotImplementedError + self._ssl_protocol._write_appdata(data) def can_write_eof(self): """Return True if this transport supports write_eof(), False if not.""" @@ -240,36 +396,23 @@ class _SSLProtocolTransport(transports._FlowControlMixin, The protocol's connection_lost() method will (eventually) be called with None as its argument. """ - self._closed = True self._ssl_protocol._abort() - - def _force_close(self, exc): self._closed = True - self._ssl_protocol._abort(exc) - def _test__append_write_backlog(self, data): - # for test only - self._ssl_protocol._write_backlog.append(data) - self._ssl_protocol._write_buffer_size += len(data) +class SSLProtocol(protocols.Protocol): + """SSL protocol. -class SSLProtocol(protocols.BufferedProtocol): - max_size = 256 * 1024 # Buffer size passed to read() - - _handshake_start_time = None - _handshake_timeout_handle = None - _shutdown_timeout_handle = None + Implementation of SSL on top of a socket using incoming and outgoing + buffers which are ssl.MemoryBIO objects. + """ def __init__(self, loop, app_protocol, sslcontext, waiter, server_side=False, server_hostname=None, call_connection_made=True, - ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_handshake_timeout=None): if ssl is None: - raise RuntimeError("stdlib ssl module not available") - - self._ssl_buffer = bytearray(self.max_size) - self._ssl_buffer_view = memoryview(self._ssl_buffer) + raise RuntimeError('stdlib ssl module not available') if ssl_handshake_timeout is None: ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT @@ -277,12 +420,6 @@ class SSLProtocol(protocols.BufferedProtocol): raise ValueError( f"ssl_handshake_timeout should be a positive number, " f"got {ssl_handshake_timeout}") - if ssl_shutdown_timeout is None: - ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT - elif ssl_shutdown_timeout <= 0: - raise ValueError( - f"ssl_shutdown_timeout should be a positive number, " - f"got {ssl_shutdown_timeout}") if not sslcontext: sslcontext = _create_transport_context( @@ -305,54 +442,21 @@ class SSLProtocol(protocols.BufferedProtocol): self._waiter = waiter self._loop = loop self._set_app_protocol(app_protocol) - self._app_transport = None - self._app_transport_created = False + self._app_transport = _SSLProtocolTransport(self._loop, self) + # _SSLPipe instance (None until the connection is made) + self._sslpipe = None + self._session_established = False + self._in_handshake = False + self._in_shutdown = False # transport, ex: SelectorSocketTransport self._transport = None + self._call_connection_made = call_connection_made self._ssl_handshake_timeout = ssl_handshake_timeout - self._ssl_shutdown_timeout = ssl_shutdown_timeout - # SSL and state machine - self._incoming = ssl.MemoryBIO() - self._outgoing = ssl.MemoryBIO() - self._state = SSLProtocolState.UNWRAPPED - self._conn_lost = 0 # Set when connection_lost called - if call_connection_made: - self._app_state = AppProtocolState.STATE_INIT - else: - self._app_state = AppProtocolState.STATE_CON_MADE - self._sslobj = self._sslcontext.wrap_bio( - self._incoming, self._outgoing, - server_side=self._server_side, - server_hostname=self._server_hostname) - - # Flow Control - - self._ssl_writing_paused = False - - self._app_reading_paused = False - - self._ssl_reading_paused = False - self._incoming_high_water = 0 - self._incoming_low_water = 0 - self._set_read_buffer_limits() - self._eof_received = False - - self._app_writing_paused = False - self._outgoing_high_water = 0 - self._outgoing_low_water = 0 - self._set_write_buffer_limits() - self._get_app_transport() def _set_app_protocol(self, app_protocol): self._app_protocol = app_protocol - # Make fast hasattr check first - if (hasattr(app_protocol, 'get_buffer') and - isinstance(app_protocol, protocols.BufferedProtocol)): - self._app_protocol_get_buffer = app_protocol.get_buffer - self._app_protocol_buffer_updated = app_protocol.buffer_updated - self._app_protocol_is_buffer = True - else: - self._app_protocol_is_buffer = False + self._app_protocol_is_buffer = \ + isinstance(app_protocol, protocols.BufferedProtocol) def _wakeup_waiter(self, exc=None): if self._waiter is None: @@ -364,20 +468,15 @@ class SSLProtocol(protocols.BufferedProtocol): self._waiter.set_result(None) self._waiter = None - def _get_app_transport(self): - if self._app_transport is None: - if self._app_transport_created: - raise RuntimeError('Creating _SSLProtocolTransport twice') - self._app_transport = _SSLProtocolTransport(self._loop, self) - self._app_transport_created = True - return self._app_transport - def connection_made(self, transport): """Called when the low-level connection is made. Start the SSL handshake. """ self._transport = transport + self._sslpipe = _SSLPipe(self._sslcontext, + self._server_side, + self._server_hostname) self._start_handshake() def connection_lost(self, exc): @@ -387,58 +486,72 @@ class SSLProtocol(protocols.BufferedProtocol): meaning a regular EOF is received or the connection was aborted or closed). """ - self._write_backlog.clear() - self._outgoing.read() - self._conn_lost += 1 - - # Just mark the app transport as closed so that its __dealloc__ - # doesn't complain. - if self._app_transport is not None: - self._app_transport._closed = True - - if self._state != SSLProtocolState.DO_HANDSHAKE: - if ( - self._app_state == AppProtocolState.STATE_CON_MADE or - self._app_state == AppProtocolState.STATE_EOF - ): - self._app_state = AppProtocolState.STATE_CON_LOST - self._loop.call_soon(self._app_protocol.connection_lost, exc) - self._set_state(SSLProtocolState.UNWRAPPED) + if self._session_established: + self._session_established = False + self._loop.call_soon(self._app_protocol.connection_lost, exc) + else: + # Most likely an exception occurred while in SSL handshake. + # Just mark the app transport as closed so that its __del__ + # doesn't complain. + if self._app_transport is not None: + self._app_transport._closed = True self._transport = None self._app_transport = None - self._app_protocol = None + if getattr(self, '_handshake_timeout_handle', None): + self._handshake_timeout_handle.cancel() self._wakeup_waiter(exc) + self._app_protocol = None + self._sslpipe = None - if self._shutdown_timeout_handle: - self._shutdown_timeout_handle.cancel() - self._shutdown_timeout_handle = None - if self._handshake_timeout_handle: - self._handshake_timeout_handle.cancel() - self._handshake_timeout_handle = None + def pause_writing(self): + """Called when the low-level transport's buffer goes over + the high-water mark. + """ + self._app_protocol.pause_writing() - def get_buffer(self, n): - want = n - if want <= 0 or want > self.max_size: - want = self.max_size - if len(self._ssl_buffer) < want: - self._ssl_buffer = bytearray(want) - self._ssl_buffer_view = memoryview(self._ssl_buffer) - return self._ssl_buffer_view + def resume_writing(self): + """Called when the low-level transport's buffer drains below + the low-water mark. + """ + self._app_protocol.resume_writing() - def buffer_updated(self, nbytes): - self._incoming.write(self._ssl_buffer_view[:nbytes]) + def data_received(self, data): + """Called when some SSL data is received. - if self._state == SSLProtocolState.DO_HANDSHAKE: - self._do_handshake() + The argument is a bytes object. + """ + if self._sslpipe is None: + # transport closing, sslpipe is destroyed + return - elif self._state == SSLProtocolState.WRAPPED: - self._do_read() + try: + ssldata, appdata = self._sslpipe.feed_ssldata(data) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as e: + self._fatal_error(e, 'SSL error in data received') + return - elif self._state == SSLProtocolState.FLUSHING: - self._do_flush() + for chunk in ssldata: + self._transport.write(chunk) - elif self._state == SSLProtocolState.SHUTDOWN: - self._do_shutdown() + for chunk in appdata: + if chunk: + try: + if self._app_protocol_is_buffer: + protocols._feed_data_to_buffered_proto( + self._app_protocol, chunk) + else: + self._app_protocol.data_received(chunk) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as ex: + self._fatal_error( + ex, 'application protocol failed to receive SSL data') + return + else: + self._start_shutdown() + break def eof_received(self): """Called when the other end of the low-level stream @@ -448,32 +561,19 @@ class SSLProtocol(protocols.BufferedProtocol): will close itself. If it returns a true value, closing the transport is up to the protocol. """ - self._eof_received = True try: if self._loop.get_debug(): logger.debug("%r received EOF", self) - if self._state == SSLProtocolState.DO_HANDSHAKE: - self._on_handshake_complete(ConnectionResetError) - - elif self._state == SSLProtocolState.WRAPPED: - self._set_state(SSLProtocolState.FLUSHING) - if self._app_reading_paused: - return True - else: - self._do_flush() - - elif self._state == SSLProtocolState.FLUSHING: - self._do_write() - self._set_state(SSLProtocolState.SHUTDOWN) - self._do_shutdown() + self._wakeup_waiter(ConnectionResetError) - elif self._state == SSLProtocolState.SHUTDOWN: - self._do_shutdown() - - except Exception: + if not self._in_handshake: + keep_open = self._app_protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + finally: self._transport.close() - raise def _get_extra_info(self, name, default=None): if name in self._extra: @@ -483,45 +583,19 @@ class SSLProtocol(protocols.BufferedProtocol): else: return default - def _set_state(self, new_state): - allowed = False - - if new_state == SSLProtocolState.UNWRAPPED: - allowed = True - - elif ( - self._state == SSLProtocolState.UNWRAPPED and - new_state == SSLProtocolState.DO_HANDSHAKE - ): - allowed = True - - elif ( - self._state == SSLProtocolState.DO_HANDSHAKE and - new_state == SSLProtocolState.WRAPPED - ): - allowed = True - - elif ( - self._state == SSLProtocolState.WRAPPED and - new_state == SSLProtocolState.FLUSHING - ): - allowed = True - - elif ( - self._state == SSLProtocolState.FLUSHING and - new_state == SSLProtocolState.SHUTDOWN - ): - allowed = True - - if allowed: - self._state = new_state - + def _start_shutdown(self): + if self._in_shutdown: + return + if self._in_handshake: + self._abort() else: - raise RuntimeError( - 'cannot switch state from {} to {}'.format( - self._state, new_state)) + self._in_shutdown = True + self._write_appdata(b'') - # Handshake flow + def _write_appdata(self, data): + self._write_backlog.append((data, 0)) + self._write_buffer_size += len(data) + self._process_write_backlog() def _start_handshake(self): if self._loop.get_debug(): @@ -529,18 +603,17 @@ class SSLProtocol(protocols.BufferedProtocol): self._handshake_start_time = self._loop.time() else: self._handshake_start_time = None - - self._set_state(SSLProtocolState.DO_HANDSHAKE) - - # start handshake timeout count down + self._in_handshake = True + # (b'', 1) is a special value in _process_write_backlog() to do + # the SSL handshake + self._write_backlog.append((b'', 1)) self._handshake_timeout_handle = \ self._loop.call_later(self._ssl_handshake_timeout, - lambda: self._check_handshake_timeout()) - - self._do_handshake() + self._check_handshake_timeout) + self._process_write_backlog() def _check_handshake_timeout(self): - if self._state == SSLProtocolState.DO_HANDSHAKE: + if self._in_handshake is True: msg = ( f"SSL handshake is taking longer than " f"{self._ssl_handshake_timeout} seconds: " @@ -548,37 +621,24 @@ class SSLProtocol(protocols.BufferedProtocol): ) self._fatal_error(ConnectionAbortedError(msg)) - def _do_handshake(self): - try: - self._sslobj.do_handshake() - except SSLAgainErrors: - self._process_outgoing() - except ssl.SSLError as exc: - self._on_handshake_complete(exc) - else: - self._on_handshake_complete(None) - def _on_handshake_complete(self, handshake_exc): - if self._handshake_timeout_handle is not None: - self._handshake_timeout_handle.cancel() - self._handshake_timeout_handle = None + self._in_handshake = False + self._handshake_timeout_handle.cancel() - sslobj = self._sslobj + sslobj = self._sslpipe.ssl_object try: - if handshake_exc is None: - self._set_state(SSLProtocolState.WRAPPED) - else: + if handshake_exc is not None: raise handshake_exc peercert = sslobj.getpeercert() - except Exception as exc: - self._set_state(SSLProtocolState.UNWRAPPED) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: if isinstance(exc, ssl.CertificateError): msg = 'SSL handshake failed on verifying the certificate' else: msg = 'SSL handshake failed' self._fatal_error(exc, msg) - self._wakeup_waiter(exc) return if self._loop.get_debug(): @@ -589,330 +649,85 @@ class SSLProtocol(protocols.BufferedProtocol): self._extra.update(peercert=peercert, cipher=sslobj.cipher(), compression=sslobj.compression(), - ssl_object=sslobj) - if self._app_state == AppProtocolState.STATE_INIT: - self._app_state = AppProtocolState.STATE_CON_MADE - self._app_protocol.connection_made(self._get_app_transport()) + ssl_object=sslobj, + ) + if self._call_connection_made: + self._app_protocol.connection_made(self._app_transport) self._wakeup_waiter() - self._do_read() - - # Shutdown flow - - def _start_shutdown(self): - if ( - self._state in ( - SSLProtocolState.FLUSHING, - SSLProtocolState.SHUTDOWN, - SSLProtocolState.UNWRAPPED - ) - ): - return - if self._app_transport is not None: - self._app_transport._closed = True - if self._state == SSLProtocolState.DO_HANDSHAKE: - self._abort() - else: - self._set_state(SSLProtocolState.FLUSHING) - self._shutdown_timeout_handle = self._loop.call_later( - self._ssl_shutdown_timeout, - lambda: self._check_shutdown_timeout() - ) - self._do_flush() - - def _check_shutdown_timeout(self): - if ( - self._state in ( - SSLProtocolState.FLUSHING, - SSLProtocolState.SHUTDOWN - ) - ): - self._transport._force_close( - exceptions.TimeoutError('SSL shutdown timed out')) - - def _do_flush(self): - self._do_read() - self._set_state(SSLProtocolState.SHUTDOWN) - self._do_shutdown() - - def _do_shutdown(self): - try: - if not self._eof_received: - self._sslobj.unwrap() - except SSLAgainErrors: - self._process_outgoing() - except ssl.SSLError as exc: - self._on_shutdown_complete(exc) - else: - self._process_outgoing() - self._call_eof_received() - self._on_shutdown_complete(None) - - def _on_shutdown_complete(self, shutdown_exc): - if self._shutdown_timeout_handle is not None: - self._shutdown_timeout_handle.cancel() - self._shutdown_timeout_handle = None - - if shutdown_exc: - self._fatal_error(shutdown_exc) - else: - self._loop.call_soon(self._transport.close) - - def _abort(self): - self._set_state(SSLProtocolState.UNWRAPPED) - if self._transport is not None: - self._transport.abort() - - # Outgoing flow - - def _write_appdata(self, list_of_data): - if ( - self._state in ( - SSLProtocolState.FLUSHING, - SSLProtocolState.SHUTDOWN, - SSLProtocolState.UNWRAPPED - ) - ): - if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - logger.warning('SSL connection is closed') - self._conn_lost += 1 + self._session_established = True + # In case transport.write() was already called. Don't call + # immediately _process_write_backlog(), but schedule it: + # _on_handshake_complete() can be called indirectly from + # _process_write_backlog(), and _process_write_backlog() is not + # reentrant. + self._loop.call_soon(self._process_write_backlog) + + def _process_write_backlog(self): + # Try to make progress on the write backlog. + if self._transport is None or self._sslpipe is None: return - for data in list_of_data: - self._write_backlog.append(data) - self._write_buffer_size += len(data) - try: - if self._state == SSLProtocolState.WRAPPED: - self._do_write() - - except Exception as ex: - self._fatal_error(ex, 'Fatal error on SSL protocol') - - def _do_write(self): - try: - while self._write_backlog: - data = self._write_backlog[0] - count = self._sslobj.write(data) - data_len = len(data) - if count < data_len: - self._write_backlog[0] = data[count:] - self._write_buffer_size -= count + for i in range(len(self._write_backlog)): + data, offset = self._write_backlog[0] + if data: + ssldata, offset = self._sslpipe.feed_appdata(data, offset) + elif offset: + ssldata = self._sslpipe.do_handshake( + self._on_handshake_complete) + offset = 1 else: - del self._write_backlog[0] - self._write_buffer_size -= data_len - except SSLAgainErrors: - pass - self._process_outgoing() - - def _process_outgoing(self): - if not self._ssl_writing_paused: - data = self._outgoing.read() - if len(data): - self._transport.write(data) - self._control_app_writing() - - # Incoming flow - - def _do_read(self): - if ( - self._state not in ( - SSLProtocolState.WRAPPED, - SSLProtocolState.FLUSHING, - ) - ): - return - try: - if not self._app_reading_paused: - if self._app_protocol_is_buffer: - self._do_read__buffered() - else: - self._do_read__copied() - if self._write_backlog: - self._do_write() - else: - self._process_outgoing() - self._control_ssl_reading() - except Exception as ex: - self._fatal_error(ex, 'Fatal error on SSL protocol') - - def _do_read__buffered(self): - offset = 0 - count = 1 - - buf = self._app_protocol_get_buffer(self._get_read_buffer_size()) - wants = len(buf) - - try: - count = self._sslobj.read(wants, buf) - - if count > 0: - offset = count - while offset < wants: - count = self._sslobj.read(wants - offset, buf[offset:]) - if count > 0: - offset += count - else: - break - else: - self._loop.call_soon(lambda: self._do_read()) - except SSLAgainErrors: - pass - if offset > 0: - self._app_protocol_buffer_updated(offset) - if not count: - # close_notify - self._call_eof_received() - self._start_shutdown() - - def _do_read__copied(self): - chunk = b'1' - zero = True - one = False - - try: - while True: - chunk = self._sslobj.read(self.max_size) - if not chunk: + ssldata = self._sslpipe.shutdown(self._finalize) + offset = 1 + + for chunk in ssldata: + self._transport.write(chunk) + + if offset < len(data): + self._write_backlog[0] = (data, offset) + # A short write means that a write is blocked on a read + # We need to enable reading if it is paused! + assert self._sslpipe.need_ssldata + if self._transport._paused: + self._transport.resume_reading() break - if zero: - zero = False - one = True - first = chunk - elif one: - one = False - data = [first, chunk] - else: - data.append(chunk) - except SSLAgainErrors: - pass - if one: - self._app_protocol.data_received(first) - elif not zero: - self._app_protocol.data_received(b''.join(data)) - if not chunk: - # close_notify - self._call_eof_received() - self._start_shutdown() - - def _call_eof_received(self): - try: - if self._app_state == AppProtocolState.STATE_CON_MADE: - self._app_state = AppProtocolState.STATE_EOF - keep_open = self._app_protocol.eof_received() - if keep_open: - logger.warning('returning true from eof_received() ' - 'has no effect when using ssl') - except (KeyboardInterrupt, SystemExit): - raise - except BaseException as ex: - self._fatal_error(ex, 'Error calling eof_received()') - - # Flow control for writes from APP socket - def _control_app_writing(self): - size = self._get_write_buffer_size() - if size >= self._outgoing_high_water and not self._app_writing_paused: - self._app_writing_paused = True - try: - self._app_protocol.pause_writing() - except (KeyboardInterrupt, SystemExit): - raise - except BaseException as exc: - self._loop.call_exception_handler({ - 'message': 'protocol.pause_writing() failed', - 'exception': exc, - 'transport': self._app_transport, - 'protocol': self, - }) - elif size <= self._outgoing_low_water and self._app_writing_paused: - self._app_writing_paused = False - try: - self._app_protocol.resume_writing() - except (KeyboardInterrupt, SystemExit): - raise - except BaseException as exc: - self._loop.call_exception_handler({ - 'message': 'protocol.resume_writing() failed', - 'exception': exc, - 'transport': self._app_transport, - 'protocol': self, - }) - - def _get_write_buffer_size(self): - return self._outgoing.pending + self._write_buffer_size - - def _set_write_buffer_limits(self, high=None, low=None): - high, low = add_flowcontrol_defaults( - high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE) - self._outgoing_high_water = high - self._outgoing_low_water = low - - # Flow control for reads to APP socket - - def _pause_reading(self): - self._app_reading_paused = True - - def _resume_reading(self): - if self._app_reading_paused: - self._app_reading_paused = False - - def resume(): - if self._state == SSLProtocolState.WRAPPED: - self._do_read() - elif self._state == SSLProtocolState.FLUSHING: - self._do_flush() - elif self._state == SSLProtocolState.SHUTDOWN: - self._do_shutdown() - self._loop.call_soon(resume) - - # Flow control for reads from SSL socket - - def _control_ssl_reading(self): - size = self._get_read_buffer_size() - if size >= self._incoming_high_water and not self._ssl_reading_paused: - self._ssl_reading_paused = True - self._transport.pause_reading() - elif size <= self._incoming_low_water and self._ssl_reading_paused: - self._ssl_reading_paused = False - self._transport.resume_reading() - - def _set_read_buffer_limits(self, high=None, low=None): - high, low = add_flowcontrol_defaults( - high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ) - self._incoming_high_water = high - self._incoming_low_water = low - - def _get_read_buffer_size(self): - return self._incoming.pending - - # Flow control for writes to SSL socket - - def pause_writing(self): - """Called when the low-level transport's buffer goes over - the high-water mark. - """ - assert not self._ssl_writing_paused - self._ssl_writing_paused = True - - def resume_writing(self): - """Called when the low-level transport's buffer drains below - the low-water mark. - """ - assert self._ssl_writing_paused - self._ssl_writing_paused = False - self._process_outgoing() + # An entire chunk from the backlog was processed. We can + # delete it and reduce the outstanding buffer size. + del self._write_backlog[0] + self._write_buffer_size -= len(data) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + if self._in_handshake: + # Exceptions will be re-raised in _on_handshake_complete. + self._on_handshake_complete(exc) + else: + self._fatal_error(exc, 'Fatal error on SSL transport') def _fatal_error(self, exc, message='Fatal error on transport'): - if self._transport: - self._transport._force_close(exc) - if isinstance(exc, OSError): if self._loop.get_debug(): logger.debug("%r: %s", self, message, exc_info=True) - elif not isinstance(exc, exceptions.CancelledError): + else: self._loop.call_exception_handler({ 'message': message, 'exception': exc, 'transport': self._transport, 'protocol': self, }) + if self._transport: + self._transport._force_close(exc) + + def _finalize(self): + self._sslpipe = None + + if self._transport is not None: + self._transport.close() + + def _abort(self): + try: + if self._transport is not None: + self._transport.abort() + finally: + self._finalize() |