summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKumar Aditya <59607654+kumaraditya303@users.noreply.github.com>2022-02-15 13:04:00 (GMT)
committerGitHub <noreply@github.com>2022-02-15 13:04:00 (GMT)
commit13c10bfb777483c7b02877aab029345a056b809c (patch)
tree4a94952a81baef1c7ceef4edc5f5d5cc6e33e2e9
parent3be1a443ca8e7d4ba85f95b78df5c4122cae9ede (diff)
downloadcpython-13c10bfb777483c7b02877aab029345a056b809c.zip
cpython-13c10bfb777483c7b02877aab029345a056b809c.tar.gz
cpython-13c10bfb777483c7b02877aab029345a056b809c.tar.bz2
bpo-44011: New asyncio ssl implementation (#31275)
* bpo-44011: New asyncio ssl implementation Co-Authored-By: Andrew Svetlov <andrew.svetlov@gmail.com> * fix warning * fix typo Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
-rw-r--r--Lib/asyncio/base_events.py43
-rw-r--r--Lib/asyncio/constants.py7
-rw-r--r--Lib/asyncio/events.py19
-rw-r--r--Lib/asyncio/proactor_events.py12
-rw-r--r--Lib/asyncio/selector_events.py31
-rw-r--r--Lib/asyncio/sslproto.py1059
-rw-r--r--Lib/asyncio/unix_events.py17
-rw-r--r--Lib/test/test_asyncio/test_base_events.py21
-rw-r--r--Lib/test/test_asyncio/test_selector_events.py38
-rw-r--r--Lib/test/test_asyncio/test_ssl.py1721
-rw-r--r--Lib/test/test_asyncio/test_sslproto.py35
-rw-r--r--Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst2
12 files changed, 2478 insertions, 527 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 56ea7ba..703c8a4 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -269,7 +269,7 @@ class _SendfileFallbackProtocol(protocols.Protocol):
class Server(events.AbstractServer):
def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
- ssl_handshake_timeout):
+ ssl_handshake_timeout, ssl_shutdown_timeout=None):
self._loop = loop
self._sockets = sockets
self._active_count = 0
@@ -278,6 +278,7 @@ class Server(events.AbstractServer):
self._backlog = backlog
self._ssl_context = ssl_context
self._ssl_handshake_timeout = ssl_handshake_timeout
+ self._ssl_shutdown_timeout = ssl_shutdown_timeout
self._serving = False
self._serving_forever_fut = None
@@ -309,7 +310,8 @@ class Server(events.AbstractServer):
sock.listen(self._backlog)
self._loop._start_serving(
self._protocol_factory, sock, self._ssl_context,
- self, self._backlog, self._ssl_handshake_timeout)
+ self, self._backlog, self._ssl_handshake_timeout,
+ self._ssl_shutdown_timeout)
def get_loop(self):
return self._loop
@@ -463,6 +465,7 @@ class BaseEventLoop(events.AbstractEventLoop):
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None,
call_connection_made=True):
"""Create SSL transport."""
raise NotImplementedError
@@ -965,6 +968,7 @@ class BaseEventLoop(events.AbstractEventLoop):
proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None,
ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
"""Connect to a TCP server.
@@ -1000,6 +1004,10 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
+ if ssl_shutdown_timeout is not None and not ssl:
+ raise ValueError(
+ 'ssl_shutdown_timeout is only meaningful with ssl')
+
if happy_eyeballs_delay is not None and interleave is None:
# If using happy eyeballs, default to interleave addresses by family
interleave = 1
@@ -1075,7 +1083,8 @@ class BaseEventLoop(events.AbstractEventLoop):
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
@@ -1087,7 +1096,8 @@ class BaseEventLoop(events.AbstractEventLoop):
async def _create_connection_transport(
self, sock, protocol_factory, ssl,
server_hostname, server_side=False,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
sock.setblocking(False)
@@ -1098,7 +1108,8 @@ class BaseEventLoop(events.AbstractEventLoop):
transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
transport = self._make_socket_transport(sock, protocol, waiter)
@@ -1189,7 +1200,8 @@ class BaseEventLoop(events.AbstractEventLoop):
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
"""Upgrade transport to TLS.
Return a new transport that *protocol* should start using
@@ -1212,6 +1224,7 @@ class BaseEventLoop(events.AbstractEventLoop):
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout,
call_connection_made=False)
# Pause early so that "ssl_protocol.data_received()" doesn't
@@ -1397,6 +1410,7 @@ class BaseEventLoop(events.AbstractEventLoop):
reuse_address=None,
reuse_port=None,
ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None,
start_serving=True):
"""Create a TCP server.
@@ -1420,6 +1434,10 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
+ if ssl_shutdown_timeout is not None and ssl is None:
+ raise ValueError(
+ 'ssl_shutdown_timeout is only meaningful with ssl')
+
if host is not None or port is not None:
if sock is not None:
raise ValueError(
@@ -1492,7 +1510,8 @@ class BaseEventLoop(events.AbstractEventLoop):
sock.setblocking(False)
server = Server(self, sockets, protocol_factory,
- ssl, backlog, ssl_handshake_timeout)
+ ssl, backlog, ssl_handshake_timeout,
+ ssl_shutdown_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
@@ -1506,7 +1525,8 @@ class BaseEventLoop(events.AbstractEventLoop):
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
if sock.type != socket.SOCK_STREAM:
raise ValueError(f'A Stream Socket was expected, got {sock!r}')
@@ -1514,9 +1534,14 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
+ if ssl_shutdown_timeout is not None and not ssl:
+ raise ValueError(
+ 'ssl_shutdown_timeout is only meaningful with ssl')
+
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True,
- ssl_handshake_timeout=ssl_handshake_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
diff --git a/Lib/asyncio/constants.py b/Lib/asyncio/constants.py
index 33feed6..f171ead 100644
--- a/Lib/asyncio/constants.py
+++ b/Lib/asyncio/constants.py
@@ -15,10 +15,17 @@ DEBUG_STACK_DEPTH = 10
# The default timeout matches that of Nginx.
SSL_HANDSHAKE_TIMEOUT = 60.0
+# Number of seconds to wait for SSL shutdown to complete
+# The default timeout mimics lingering_time
+SSL_SHUTDOWN_TIMEOUT = 30.0
+
# Used in sendfile fallback code. We use fallback for platforms
# that don't support sendfile, or for TLS connections.
SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256
+FLOW_CONTROL_HIGH_WATER_SSL_READ = 256 # KiB
+FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512 # KiB
+
# The enum should be here to break circular dependencies between
# base_events and sslproto
class _SendfileMode(enum.Enum):
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index e3c55b2..1d305e3 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -303,6 +303,7 @@ class AbstractEventLoop:
flags=0, sock=None, local_addr=None,
server_hostname=None,
ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
raise NotImplementedError
@@ -312,6 +313,7 @@ class AbstractEventLoop:
flags=socket.AI_PASSIVE, sock=None, backlog=100,
ssl=None, reuse_address=None, reuse_port=None,
ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None,
start_serving=True):
"""A coroutine which creates a TCP server bound to host and port.
@@ -352,6 +354,10 @@ class AbstractEventLoop:
will wait for completion of the SSL handshake before aborting the
connection. Default is 60s.
+ ssl_shutdown_timeout is the time in seconds that an SSL server
+ will wait for completion of the SSL shutdown procedure
+ before aborting the connection. Default is 30s.
+
start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever()
@@ -370,7 +376,8 @@ class AbstractEventLoop:
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
"""Upgrade a transport to TLS.
Return a new transport that *protocol* should start using
@@ -382,13 +389,15 @@ class AbstractEventLoop:
self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
raise NotImplementedError
async def create_unix_server(
self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None,
start_serving=True):
"""A coroutine which creates a UNIX Domain Socket server.
@@ -410,6 +419,9 @@ class AbstractEventLoop:
ssl_handshake_timeout is the time in seconds that an SSL server
will wait for the SSL handshake to complete (defaults to 60s).
+ ssl_shutdown_timeout is the time in seconds that an SSL server
+ will wait for the SSL shutdown to finish (defaults to 30s).
+
start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever()
@@ -420,7 +432,8 @@ class AbstractEventLoop:
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
"""Handle an accepted connection.
This is used by servers that accept connections outside of
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index 1d9e2fe..ae59f30 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -642,11 +642,13 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout)
_ProactorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
@@ -812,7 +814,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
def loop(f=None):
try:
@@ -826,7 +829,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
self._make_ssl_transport(
conn, protocol, sslcontext, server_side=True,
extra={'peername': addr}, server=server,
- ssl_handshake_timeout=ssl_handshake_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
self._make_socket_transport(
conn, protocol,
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index 59cb6b1..63ab15f 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -70,11 +70,15 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
- ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
+ ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
+ ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
+ ):
ssl_protocol = sslproto.SSLProtocol(
- self, protocol, sslcontext, waiter,
- server_side, server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout)
+ self, protocol, sslcontext, waiter,
+ server_side, server_hostname,
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout
+ )
_SelectorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
@@ -146,15 +150,17 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
- ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
+ ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
+ ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
self._add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog,
- ssl_handshake_timeout)
+ ssl_handshake_timeout, ssl_shutdown_timeout)
def _accept_connection(
self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
- ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
+ ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
+ ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
# This method is only called once for each event loop tick where the
# listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop.
@@ -185,20 +191,22 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
self.call_later(constants.ACCEPT_RETRY_DELAY,
self._start_serving,
protocol_factory, sock, sslcontext, server,
- backlog, ssl_handshake_timeout)
+ backlog, ssl_handshake_timeout,
+ ssl_shutdown_timeout)
else:
raise # The event loop will catch, log and ignore it.
else:
extra = {'peername': addr}
accept = self._accept_connection2(
protocol_factory, conn, extra, sslcontext, server,
- ssl_handshake_timeout)
+ ssl_handshake_timeout, ssl_shutdown_timeout)
self.create_task(accept)
async def _accept_connection2(
self, protocol_factory, conn, extra,
sslcontext=None, server=None,
- ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
+ ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
+ ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
protocol = None
transport = None
try:
@@ -208,7 +216,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
transport = self._make_ssl_transport(
conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra=extra, server=server,
- ssl_handshake_timeout=ssl_handshake_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra,
diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py
index 00fc16c..de7c333 100644
--- a/Lib/asyncio/sslproto.py
+++ b/Lib/asyncio/sslproto.py
@@ -1,4 +1,5 @@
import collections
+import enum
import warnings
try:
import ssl
@@ -6,10 +7,38 @@ except ImportError: # pragma: no cover
ssl = None
from . import constants
+from . import exceptions
from . import protocols
from . import transports
from .log import logger
+if ssl is not None:
+ SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError)
+
+
+class SSLProtocolState(enum.Enum):
+ UNWRAPPED = "UNWRAPPED"
+ DO_HANDSHAKE = "DO_HANDSHAKE"
+ WRAPPED = "WRAPPED"
+ FLUSHING = "FLUSHING"
+ SHUTDOWN = "SHUTDOWN"
+
+
+class AppProtocolState(enum.Enum):
+ # This tracks the state of app protocol (https://git.io/fj59P):
+ #
+ # INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST
+ #
+ # * cm: connection_made()
+ # * dr: data_received()
+ # * er: eof_received()
+ # * cl: connection_lost()
+
+ STATE_INIT = "STATE_INIT"
+ STATE_CON_MADE = "STATE_CON_MADE"
+ STATE_EOF = "STATE_EOF"
+ STATE_CON_LOST = "STATE_CON_LOST"
+
def _create_transport_context(server_side, server_hostname):
if server_side:
@@ -25,269 +54,35 @@ def _create_transport_context(server_side, server_hostname):
return sslcontext
-# States of an _SSLPipe.
-_UNWRAPPED = "UNWRAPPED"
-_DO_HANDSHAKE = "DO_HANDSHAKE"
-_WRAPPED = "WRAPPED"
-_SHUTDOWN = "SHUTDOWN"
-
-
-class _SSLPipe(object):
- """An SSL "Pipe".
-
- An SSL pipe allows you to communicate with an SSL/TLS protocol instance
- through memory buffers. It can be used to implement a security layer for an
- existing connection where you don't have access to the connection's file
- descriptor, or for some reason you don't want to use it.
-
- An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode,
- data is passed through untransformed. In wrapped mode, application level
- data is encrypted to SSL record level data and vice versa. The SSL record
- level is the lowest level in the SSL protocol suite and is what travels
- as-is over the wire.
-
- An SslPipe initially is in "unwrapped" mode. To start SSL, call
- do_handshake(). To shutdown SSL again, call unwrap().
- """
-
- max_size = 256 * 1024 # Buffer size passed to read()
-
- def __init__(self, context, server_side, server_hostname=None):
- """
- The *context* argument specifies the ssl.SSLContext to use.
-
- The *server_side* argument indicates whether this is a server side or
- client side transport.
-
- The optional *server_hostname* argument can be used to specify the
- hostname you are connecting to. You may only specify this parameter if
- the _ssl module supports Server Name Indication (SNI).
- """
- self._context = context
- self._server_side = server_side
- self._server_hostname = server_hostname
- self._state = _UNWRAPPED
- self._incoming = ssl.MemoryBIO()
- self._outgoing = ssl.MemoryBIO()
- self._sslobj = None
- self._need_ssldata = False
- self._handshake_cb = None
- self._shutdown_cb = None
-
- @property
- def context(self):
- """The SSL context passed to the constructor."""
- return self._context
-
- @property
- def ssl_object(self):
- """The internal ssl.SSLObject instance.
-
- Return None if the pipe is not wrapped.
- """
- return self._sslobj
-
- @property
- def need_ssldata(self):
- """Whether more record level data is needed to complete a handshake
- that is currently in progress."""
- return self._need_ssldata
-
- @property
- def wrapped(self):
- """
- Whether a security layer is currently in effect.
-
- Return False during handshake.
- """
- return self._state == _WRAPPED
-
- def do_handshake(self, callback=None):
- """Start the SSL handshake.
-
- Return a list of ssldata. A ssldata element is a list of buffers
-
- The optional *callback* argument can be used to install a callback that
- will be called when the handshake is complete. The callback will be
- called with None if successful, else an exception instance.
- """
- if self._state != _UNWRAPPED:
- raise RuntimeError('handshake in progress or completed')
- self._sslobj = self._context.wrap_bio(
- self._incoming, self._outgoing,
- server_side=self._server_side,
- server_hostname=self._server_hostname)
- self._state = _DO_HANDSHAKE
- self._handshake_cb = callback
- ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
- assert len(appdata) == 0
- return ssldata
-
- def shutdown(self, callback=None):
- """Start the SSL shutdown sequence.
-
- Return a list of ssldata. A ssldata element is a list of buffers
-
- The optional *callback* argument can be used to install a callback that
- will be called when the shutdown is complete. The callback will be
- called without arguments.
- """
- if self._state == _UNWRAPPED:
- raise RuntimeError('no security layer present')
- if self._state == _SHUTDOWN:
- raise RuntimeError('shutdown in progress')
- assert self._state in (_WRAPPED, _DO_HANDSHAKE)
- self._state = _SHUTDOWN
- self._shutdown_cb = callback
- ssldata, appdata = self.feed_ssldata(b'')
- assert appdata == [] or appdata == [b'']
- return ssldata
-
- def feed_eof(self):
- """Send a potentially "ragged" EOF.
-
- This method will raise an SSL_ERROR_EOF exception if the EOF is
- unexpected.
- """
- self._incoming.write_eof()
- ssldata, appdata = self.feed_ssldata(b'')
- assert appdata == [] or appdata == [b'']
-
- def feed_ssldata(self, data, only_handshake=False):
- """Feed SSL record level data into the pipe.
-
- The data must be a bytes instance. It is OK to send an empty bytes
- instance. This can be used to get ssldata for a handshake initiated by
- this endpoint.
-
- Return a (ssldata, appdata) tuple. The ssldata element is a list of
- buffers containing SSL data that needs to be sent to the remote SSL.
-
- The appdata element is a list of buffers containing plaintext data that
- needs to be forwarded to the application. The appdata list may contain
- an empty buffer indicating an SSL "close_notify" alert. This alert must
- be acknowledged by calling shutdown().
- """
- if self._state == _UNWRAPPED:
- # If unwrapped, pass plaintext data straight through.
- if data:
- appdata = [data]
- else:
- appdata = []
- return ([], appdata)
+def add_flowcontrol_defaults(high, low, kb):
+ if high is None:
+ if low is None:
+ hi = kb * 1024
+ else:
+ lo = low
+ hi = 4 * lo
+ else:
+ hi = high
+ if low is None:
+ lo = hi // 4
+ else:
+ lo = low
- self._need_ssldata = False
- if data:
- self._incoming.write(data)
+ if not hi >= lo >= 0:
+ raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
+ (hi, lo))
- ssldata = []
- appdata = []
- try:
- if self._state == _DO_HANDSHAKE:
- # Call do_handshake() until it doesn't raise anymore.
- self._sslobj.do_handshake()
- self._state = _WRAPPED
- if self._handshake_cb:
- self._handshake_cb(None)
- if only_handshake:
- return (ssldata, appdata)
- # Handshake done: execute the wrapped block
-
- if self._state == _WRAPPED:
- # Main state: read data from SSL until close_notify
- while True:
- chunk = self._sslobj.read(self.max_size)
- appdata.append(chunk)
- if not chunk: # close_notify
- break
-
- elif self._state == _SHUTDOWN:
- # Call shutdown() until it doesn't raise anymore.
- self._sslobj.unwrap()
- self._sslobj = None
- self._state = _UNWRAPPED
- if self._shutdown_cb:
- self._shutdown_cb()
-
- elif self._state == _UNWRAPPED:
- # Drain possible plaintext data after close_notify.
- appdata.append(self._incoming.read())
- except (ssl.SSLError, ssl.CertificateError) as exc:
- exc_errno = getattr(exc, 'errno', None)
- if exc_errno not in (
- ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
- ssl.SSL_ERROR_SYSCALL):
- if self._state == _DO_HANDSHAKE and self._handshake_cb:
- self._handshake_cb(exc)
- raise
- self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
-
- # Check for record level data that needs to be sent back.
- # Happens for the initial handshake and renegotiations.
- if self._outgoing.pending:
- ssldata.append(self._outgoing.read())
- return (ssldata, appdata)
-
- def feed_appdata(self, data, offset=0):
- """Feed plaintext data into the pipe.
-
- Return an (ssldata, offset) tuple. The ssldata element is a list of
- buffers containing record level data that needs to be sent to the
- remote SSL instance. The offset is the number of plaintext bytes that
- were processed, which may be less than the length of data.
-
- NOTE: In case of short writes, this call MUST be retried with the SAME
- buffer passed into the *data* argument (i.e. the id() must be the
- same). This is an OpenSSL requirement. A further particularity is that
- a short write will always have offset == 0, because the _ssl module
- does not enable partial writes. And even though the offset is zero,
- there will still be encrypted data in ssldata.
- """
- assert 0 <= offset <= len(data)
- if self._state == _UNWRAPPED:
- # pass through data in unwrapped mode
- if offset < len(data):
- ssldata = [data[offset:]]
- else:
- ssldata = []
- return (ssldata, len(data))
-
- ssldata = []
- view = memoryview(data)
- while True:
- self._need_ssldata = False
- try:
- if offset < len(view):
- offset += self._sslobj.write(view[offset:])
- except ssl.SSLError as exc:
- # It is not allowed to call write() after unwrap() until the
- # close_notify is acknowledged. We return the condition to the
- # caller as a short write.
- exc_errno = getattr(exc, 'errno', None)
- if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
- exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
- if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
- ssl.SSL_ERROR_WANT_WRITE,
- ssl.SSL_ERROR_SYSCALL):
- raise
- self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
-
- # See if there's any record level data back for us.
- if self._outgoing.pending:
- ssldata.append(self._outgoing.read())
- if offset == len(view) or self._need_ssldata:
- break
- return (ssldata, offset)
+ return hi, lo
class _SSLProtocolTransport(transports._FlowControlMixin,
transports.Transport):
+ _start_tls_compatible = True
_sendfile_compatible = constants._SendfileMode.FALLBACK
def __init__(self, loop, ssl_protocol):
self._loop = loop
- # SSLProtocol instance
self._ssl_protocol = ssl_protocol
self._closed = False
@@ -315,16 +110,15 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
self._closed = True
self._ssl_protocol._start_shutdown()
- def __del__(self, _warn=warnings.warn):
+ def __del__(self, _warnings=warnings):
if not self._closed:
- _warn(f"unclosed transport {self!r}", ResourceWarning, source=self)
- self.close()
+ self._closed = True
+ _warnings.warn(
+ "unclosed transport <asyncio._SSLProtocolTransport "
+ "object>", ResourceWarning)
def is_reading(self):
- tr = self._ssl_protocol._transport
- if tr is None:
- raise RuntimeError('SSL transport has not been initialized yet')
- return tr.is_reading()
+ return not self._ssl_protocol._app_reading_paused
def pause_reading(self):
"""Pause the receiving end.
@@ -332,7 +126,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
No data will be passed to the protocol's data_received()
method until resume_reading() is called.
"""
- self._ssl_protocol._transport.pause_reading()
+ self._ssl_protocol._pause_reading()
def resume_reading(self):
"""Resume the receiving end.
@@ -340,7 +134,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
Data received will once again be passed to the protocol's
data_received() method.
"""
- self._ssl_protocol._transport.resume_reading()
+ self._ssl_protocol._resume_reading()
def set_write_buffer_limits(self, high=None, low=None):
"""Set the high- and low-water limits for write flow control.
@@ -361,11 +155,46 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
reduces opportunities for doing I/O and computation
concurrently.
"""
- self._ssl_protocol._transport.set_write_buffer_limits(high, low)
+ self._ssl_protocol._set_write_buffer_limits(high, low)
+ self._ssl_protocol._control_app_writing()
+
+ def get_write_buffer_limits(self):
+ return (self._ssl_protocol._outgoing_low_water,
+ self._ssl_protocol._outgoing_high_water)
def get_write_buffer_size(self):
- """Return the current size of the write buffer."""
- return self._ssl_protocol._transport.get_write_buffer_size()
+ """Return the current size of the write buffers."""
+ return self._ssl_protocol._get_write_buffer_size()
+
+ def set_read_buffer_limits(self, high=None, low=None):
+ """Set the high- and low-water limits for read flow control.
+
+ These two values control when to call the upstream transport's
+ pause_reading() and resume_reading() methods. If specified,
+ the low-water limit must be less than or equal to the
+ high-water limit. Neither value can be negative.
+
+ The defaults are implementation-specific. If only the
+ high-water limit is given, the low-water limit defaults to an
+ implementation-specific value less than or equal to the
+ high-water limit. Setting high to zero forces low to zero as
+ well, and causes pause_reading() to be called whenever the
+ buffer becomes non-empty. Setting low to zero causes
+ resume_reading() to be called only once the buffer is empty.
+ Use of zero for either limit is generally sub-optimal as it
+ reduces opportunities for doing I/O and computation
+ concurrently.
+ """
+ self._ssl_protocol._set_read_buffer_limits(high, low)
+ self._ssl_protocol._control_ssl_reading()
+
+ def get_read_buffer_limits(self):
+ return (self._ssl_protocol._incoming_low_water,
+ self._ssl_protocol._incoming_high_water)
+
+ def get_read_buffer_size(self):
+ """Return the current size of the read buffer."""
+ return self._ssl_protocol._get_read_buffer_size()
def get_write_buffer_limits(self):
"""Get the high and low watermarks for write flow control.
@@ -376,7 +205,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
@property
def _protocol_paused(self):
# Required for sendfile fallback pause_writing/resume_writing logic
- return self._ssl_protocol._transport._protocol_paused
+ return self._ssl_protocol._app_writing_paused
def write(self, data):
"""Write some data bytes to the transport.
@@ -389,7 +218,22 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
f"got {type(data).__name__}")
if not data:
return
- self._ssl_protocol._write_appdata(data)
+ self._ssl_protocol._write_appdata((data,))
+
+ def writelines(self, list_of_data):
+ """Write a list (or any iterable) of data bytes to the transport.
+
+ The default implementation concatenates the arguments and
+ calls write() on the result.
+ """
+ self._ssl_protocol._write_appdata(list_of_data)
+
+ def write_eof(self):
+ """Close the write end after flushing buffered data.
+
+ This raises :exc:`NotImplementedError` right now.
+ """
+ raise NotImplementedError
def can_write_eof(self):
"""Return True if this transport supports write_eof(), False if not."""
@@ -402,23 +246,36 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
The protocol's connection_lost() method will (eventually) be
called with None as its argument.
"""
+ self._closed = True
self._ssl_protocol._abort()
+
+ def _force_close(self, exc):
self._closed = True
+ self._ssl_protocol._abort(exc)
+ def _test__append_write_backlog(self, data):
+ # for test only
+ self._ssl_protocol._write_backlog.append(data)
+ self._ssl_protocol._write_buffer_size += len(data)
-class SSLProtocol(protocols.Protocol):
- """SSL protocol.
- Implementation of SSL on top of a socket using incoming and outgoing
- buffers which are ssl.MemoryBIO objects.
- """
+class SSLProtocol(protocols.BufferedProtocol):
+ max_size = 256 * 1024 # Buffer size passed to read()
+
+ _handshake_start_time = None
+ _handshake_timeout_handle = None
+ _shutdown_timeout_handle = None
def __init__(self, loop, app_protocol, sslcontext, waiter,
server_side=False, server_hostname=None,
call_connection_made=True,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
if ssl is None:
- raise RuntimeError('stdlib ssl module not available')
+ raise RuntimeError("stdlib ssl module not available")
+
+ self._ssl_buffer = bytearray(self.max_size)
+ self._ssl_buffer_view = memoryview(self._ssl_buffer)
if ssl_handshake_timeout is None:
ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
@@ -426,6 +283,12 @@ class SSLProtocol(protocols.Protocol):
raise ValueError(
f"ssl_handshake_timeout should be a positive number, "
f"got {ssl_handshake_timeout}")
+ if ssl_shutdown_timeout is None:
+ ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT
+ elif ssl_shutdown_timeout <= 0:
+ raise ValueError(
+ f"ssl_shutdown_timeout should be a positive number, "
+ f"got {ssl_shutdown_timeout}")
if not sslcontext:
sslcontext = _create_transport_context(
@@ -448,21 +311,54 @@ class SSLProtocol(protocols.Protocol):
self._waiter = waiter
self._loop = loop
self._set_app_protocol(app_protocol)
- self._app_transport = _SSLProtocolTransport(self._loop, self)
- # _SSLPipe instance (None until the connection is made)
- self._sslpipe = None
- self._session_established = False
- self._in_handshake = False
- self._in_shutdown = False
+ self._app_transport = None
+ self._app_transport_created = False
# transport, ex: SelectorSocketTransport
self._transport = None
- self._call_connection_made = call_connection_made
self._ssl_handshake_timeout = ssl_handshake_timeout
+ self._ssl_shutdown_timeout = ssl_shutdown_timeout
+ # SSL and state machine
+ self._incoming = ssl.MemoryBIO()
+ self._outgoing = ssl.MemoryBIO()
+ self._state = SSLProtocolState.UNWRAPPED
+ self._conn_lost = 0 # Set when connection_lost called
+ if call_connection_made:
+ self._app_state = AppProtocolState.STATE_INIT
+ else:
+ self._app_state = AppProtocolState.STATE_CON_MADE
+ self._sslobj = self._sslcontext.wrap_bio(
+ self._incoming, self._outgoing,
+ server_side=self._server_side,
+ server_hostname=self._server_hostname)
+
+ # Flow Control
+
+ self._ssl_writing_paused = False
+
+ self._app_reading_paused = False
+
+ self._ssl_reading_paused = False
+ self._incoming_high_water = 0
+ self._incoming_low_water = 0
+ self._set_read_buffer_limits()
+ self._eof_received = False
+
+ self._app_writing_paused = False
+ self._outgoing_high_water = 0
+ self._outgoing_low_water = 0
+ self._set_write_buffer_limits()
+ self._get_app_transport()
def _set_app_protocol(self, app_protocol):
self._app_protocol = app_protocol
- self._app_protocol_is_buffer = \
- isinstance(app_protocol, protocols.BufferedProtocol)
+ # Make fast hasattr check first
+ if (hasattr(app_protocol, 'get_buffer') and
+ isinstance(app_protocol, protocols.BufferedProtocol)):
+ self._app_protocol_get_buffer = app_protocol.get_buffer
+ self._app_protocol_buffer_updated = app_protocol.buffer_updated
+ self._app_protocol_is_buffer = True
+ else:
+ self._app_protocol_is_buffer = False
def _wakeup_waiter(self, exc=None):
if self._waiter is None:
@@ -474,15 +370,20 @@ class SSLProtocol(protocols.Protocol):
self._waiter.set_result(None)
self._waiter = None
+ def _get_app_transport(self):
+ if self._app_transport is None:
+ if self._app_transport_created:
+ raise RuntimeError('Creating _SSLProtocolTransport twice')
+ self._app_transport = _SSLProtocolTransport(self._loop, self)
+ self._app_transport_created = True
+ return self._app_transport
+
def connection_made(self, transport):
"""Called when the low-level connection is made.
Start the SSL handshake.
"""
self._transport = transport
- self._sslpipe = _SSLPipe(self._sslcontext,
- self._server_side,
- self._server_hostname)
self._start_handshake()
def connection_lost(self, exc):
@@ -492,72 +393,58 @@ class SSLProtocol(protocols.Protocol):
meaning a regular EOF is received or the connection was
aborted or closed).
"""
- if self._session_established:
- self._session_established = False
- self._loop.call_soon(self._app_protocol.connection_lost, exc)
- else:
- # Most likely an exception occurred while in SSL handshake.
- # Just mark the app transport as closed so that its __del__
- # doesn't complain.
- if self._app_transport is not None:
- self._app_transport._closed = True
+ self._write_backlog.clear()
+ self._outgoing.read()
+ self._conn_lost += 1
+
+ # Just mark the app transport as closed so that its __dealloc__
+ # doesn't complain.
+ if self._app_transport is not None:
+ self._app_transport._closed = True
+
+ if self._state != SSLProtocolState.DO_HANDSHAKE:
+ if (
+ self._app_state == AppProtocolState.STATE_CON_MADE or
+ self._app_state == AppProtocolState.STATE_EOF
+ ):
+ self._app_state = AppProtocolState.STATE_CON_LOST
+ self._loop.call_soon(self._app_protocol.connection_lost, exc)
+ self._set_state(SSLProtocolState.UNWRAPPED)
self._transport = None
self._app_transport = None
- if getattr(self, '_handshake_timeout_handle', None):
- self._handshake_timeout_handle.cancel()
- self._wakeup_waiter(exc)
self._app_protocol = None
- self._sslpipe = None
+ self._wakeup_waiter(exc)
- def pause_writing(self):
- """Called when the low-level transport's buffer goes over
- the high-water mark.
- """
- self._app_protocol.pause_writing()
+ if self._shutdown_timeout_handle:
+ self._shutdown_timeout_handle.cancel()
+ self._shutdown_timeout_handle = None
+ if self._handshake_timeout_handle:
+ self._handshake_timeout_handle.cancel()
+ self._handshake_timeout_handle = None
- def resume_writing(self):
- """Called when the low-level transport's buffer drains below
- the low-water mark.
- """
- self._app_protocol.resume_writing()
+ def get_buffer(self, n):
+ want = n
+ if want <= 0 or want > self.max_size:
+ want = self.max_size
+ if len(self._ssl_buffer) < want:
+ self._ssl_buffer = bytearray(want)
+ self._ssl_buffer_view = memoryview(self._ssl_buffer)
+ return self._ssl_buffer_view
- def data_received(self, data):
- """Called when some SSL data is received.
+ def buffer_updated(self, nbytes):
+ self._incoming.write(self._ssl_buffer_view[:nbytes])
- The argument is a bytes object.
- """
- if self._sslpipe is None:
- # transport closing, sslpipe is destroyed
- return
+ if self._state == SSLProtocolState.DO_HANDSHAKE:
+ self._do_handshake()
- try:
- ssldata, appdata = self._sslpipe.feed_ssldata(data)
- except (SystemExit, KeyboardInterrupt):
- raise
- except BaseException as e:
- self._fatal_error(e, 'SSL error in data received')
- return
+ elif self._state == SSLProtocolState.WRAPPED:
+ self._do_read()
- for chunk in ssldata:
- self._transport.write(chunk)
+ elif self._state == SSLProtocolState.FLUSHING:
+ self._do_flush()
- for chunk in appdata:
- if chunk:
- try:
- if self._app_protocol_is_buffer:
- protocols._feed_data_to_buffered_proto(
- self._app_protocol, chunk)
- else:
- self._app_protocol.data_received(chunk)
- except (SystemExit, KeyboardInterrupt):
- raise
- except BaseException as ex:
- self._fatal_error(
- ex, 'application protocol failed to receive SSL data')
- return
- else:
- self._start_shutdown()
- break
+ elif self._state == SSLProtocolState.SHUTDOWN:
+ self._do_shutdown()
def eof_received(self):
"""Called when the other end of the low-level stream
@@ -567,19 +454,32 @@ class SSLProtocol(protocols.Protocol):
will close itself. If it returns a true value, closing the
transport is up to the protocol.
"""
+ self._eof_received = True
try:
if self._loop.get_debug():
logger.debug("%r received EOF", self)
- self._wakeup_waiter(ConnectionResetError)
+ if self._state == SSLProtocolState.DO_HANDSHAKE:
+ self._on_handshake_complete(ConnectionResetError)
- if not self._in_handshake:
- keep_open = self._app_protocol.eof_received()
- if keep_open:
- logger.warning('returning true from eof_received() '
- 'has no effect when using ssl')
- finally:
+ elif self._state == SSLProtocolState.WRAPPED:
+ self._set_state(SSLProtocolState.FLUSHING)
+ if self._app_reading_paused:
+ return True
+ else:
+ self._do_flush()
+
+ elif self._state == SSLProtocolState.FLUSHING:
+ self._do_write()
+ self._set_state(SSLProtocolState.SHUTDOWN)
+ self._do_shutdown()
+
+ elif self._state == SSLProtocolState.SHUTDOWN:
+ self._do_shutdown()
+
+ except Exception:
self._transport.close()
+ raise
def _get_extra_info(self, name, default=None):
if name in self._extra:
@@ -589,19 +489,45 @@ class SSLProtocol(protocols.Protocol):
else:
return default
- def _start_shutdown(self):
- if self._in_shutdown:
- return
- if self._in_handshake:
- self._abort()
+ def _set_state(self, new_state):
+ allowed = False
+
+ if new_state == SSLProtocolState.UNWRAPPED:
+ allowed = True
+
+ elif (
+ self._state == SSLProtocolState.UNWRAPPED and
+ new_state == SSLProtocolState.DO_HANDSHAKE
+ ):
+ allowed = True
+
+ elif (
+ self._state == SSLProtocolState.DO_HANDSHAKE and
+ new_state == SSLProtocolState.WRAPPED
+ ):
+ allowed = True
+
+ elif (
+ self._state == SSLProtocolState.WRAPPED and
+ new_state == SSLProtocolState.FLUSHING
+ ):
+ allowed = True
+
+ elif (
+ self._state == SSLProtocolState.FLUSHING and
+ new_state == SSLProtocolState.SHUTDOWN
+ ):
+ allowed = True
+
+ if allowed:
+ self._state = new_state
+
else:
- self._in_shutdown = True
- self._write_appdata(b'')
+ raise RuntimeError(
+ 'cannot switch state from {} to {}'.format(
+ self._state, new_state))
- def _write_appdata(self, data):
- self._write_backlog.append((data, 0))
- self._write_buffer_size += len(data)
- self._process_write_backlog()
+ # Handshake flow
def _start_handshake(self):
if self._loop.get_debug():
@@ -609,17 +535,18 @@ class SSLProtocol(protocols.Protocol):
self._handshake_start_time = self._loop.time()
else:
self._handshake_start_time = None
- self._in_handshake = True
- # (b'', 1) is a special value in _process_write_backlog() to do
- # the SSL handshake
- self._write_backlog.append((b'', 1))
+
+ self._set_state(SSLProtocolState.DO_HANDSHAKE)
+
+ # start handshake timeout count down
self._handshake_timeout_handle = \
self._loop.call_later(self._ssl_handshake_timeout,
- self._check_handshake_timeout)
- self._process_write_backlog()
+ lambda: self._check_handshake_timeout())
+
+ self._do_handshake()
def _check_handshake_timeout(self):
- if self._in_handshake is True:
+ if self._state == SSLProtocolState.DO_HANDSHAKE:
msg = (
f"SSL handshake is taking longer than "
f"{self._ssl_handshake_timeout} seconds: "
@@ -627,24 +554,37 @@ class SSLProtocol(protocols.Protocol):
)
self._fatal_error(ConnectionAbortedError(msg))
+ def _do_handshake(self):
+ try:
+ self._sslobj.do_handshake()
+ except SSLAgainErrors:
+ self._process_outgoing()
+ except ssl.SSLError as exc:
+ self._on_handshake_complete(exc)
+ else:
+ self._on_handshake_complete(None)
+
def _on_handshake_complete(self, handshake_exc):
- self._in_handshake = False
- self._handshake_timeout_handle.cancel()
+ if self._handshake_timeout_handle is not None:
+ self._handshake_timeout_handle.cancel()
+ self._handshake_timeout_handle = None
- sslobj = self._sslpipe.ssl_object
+ sslobj = self._sslobj
try:
- if handshake_exc is not None:
+ if handshake_exc is None:
+ self._set_state(SSLProtocolState.WRAPPED)
+ else:
raise handshake_exc
peercert = sslobj.getpeercert()
- except (SystemExit, KeyboardInterrupt):
- raise
- except BaseException as exc:
+ except Exception as exc:
+ self._set_state(SSLProtocolState.UNWRAPPED)
if isinstance(exc, ssl.CertificateError):
msg = 'SSL handshake failed on verifying the certificate'
else:
msg = 'SSL handshake failed'
self._fatal_error(exc, msg)
+ self._wakeup_waiter(exc)
return
if self._loop.get_debug():
@@ -655,85 +595,330 @@ class SSLProtocol(protocols.Protocol):
self._extra.update(peercert=peercert,
cipher=sslobj.cipher(),
compression=sslobj.compression(),
- ssl_object=sslobj,
- )
- if self._call_connection_made:
- self._app_protocol.connection_made(self._app_transport)
+ ssl_object=sslobj)
+ if self._app_state == AppProtocolState.STATE_INIT:
+ self._app_state = AppProtocolState.STATE_CON_MADE
+ self._app_protocol.connection_made(self._get_app_transport())
self._wakeup_waiter()
- self._session_established = True
- # In case transport.write() was already called. Don't call
- # immediately _process_write_backlog(), but schedule it:
- # _on_handshake_complete() can be called indirectly from
- # _process_write_backlog(), and _process_write_backlog() is not
- # reentrant.
- self._loop.call_soon(self._process_write_backlog)
-
- def _process_write_backlog(self):
- # Try to make progress on the write backlog.
- if self._transport is None or self._sslpipe is None:
+ self._do_read()
+
+ # Shutdown flow
+
+ def _start_shutdown(self):
+ if (
+ self._state in (
+ SSLProtocolState.FLUSHING,
+ SSLProtocolState.SHUTDOWN,
+ SSLProtocolState.UNWRAPPED
+ )
+ ):
+ return
+ if self._app_transport is not None:
+ self._app_transport._closed = True
+ if self._state == SSLProtocolState.DO_HANDSHAKE:
+ self._abort()
+ else:
+ self._set_state(SSLProtocolState.FLUSHING)
+ self._shutdown_timeout_handle = self._loop.call_later(
+ self._ssl_shutdown_timeout,
+ lambda: self._check_shutdown_timeout()
+ )
+ self._do_flush()
+
+ def _check_shutdown_timeout(self):
+ if (
+ self._state in (
+ SSLProtocolState.FLUSHING,
+ SSLProtocolState.SHUTDOWN
+ )
+ ):
+ self._transport._force_close(
+ exceptions.TimeoutError('SSL shutdown timed out'))
+
+ def _do_flush(self):
+ self._do_read()
+ self._set_state(SSLProtocolState.SHUTDOWN)
+ self._do_shutdown()
+
+ def _do_shutdown(self):
+ try:
+ if not self._eof_received:
+ self._sslobj.unwrap()
+ except SSLAgainErrors:
+ self._process_outgoing()
+ except ssl.SSLError as exc:
+ self._on_shutdown_complete(exc)
+ else:
+ self._process_outgoing()
+ self._call_eof_received()
+ self._on_shutdown_complete(None)
+
+ def _on_shutdown_complete(self, shutdown_exc):
+ if self._shutdown_timeout_handle is not None:
+ self._shutdown_timeout_handle.cancel()
+ self._shutdown_timeout_handle = None
+
+ if shutdown_exc:
+ self._fatal_error(shutdown_exc)
+ else:
+ self._loop.call_soon(self._transport.close)
+
+ def _abort(self):
+ self._set_state(SSLProtocolState.UNWRAPPED)
+ if self._transport is not None:
+ self._transport.abort()
+
+ # Outgoing flow
+
+ def _write_appdata(self, list_of_data):
+ if (
+ self._state in (
+ SSLProtocolState.FLUSHING,
+ SSLProtocolState.SHUTDOWN,
+ SSLProtocolState.UNWRAPPED
+ )
+ ):
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ logger.warning('SSL connection is closed')
+ self._conn_lost += 1
return
+ for data in list_of_data:
+ self._write_backlog.append(data)
+ self._write_buffer_size += len(data)
+
try:
- for i in range(len(self._write_backlog)):
- data, offset = self._write_backlog[0]
- if data:
- ssldata, offset = self._sslpipe.feed_appdata(data, offset)
- elif offset:
- ssldata = self._sslpipe.do_handshake(
- self._on_handshake_complete)
- offset = 1
+ if self._state == SSLProtocolState.WRAPPED:
+ self._do_write()
+
+ except Exception as ex:
+ self._fatal_error(ex, 'Fatal error on SSL protocol')
+
+ def _do_write(self):
+ try:
+ while self._write_backlog:
+ data = self._write_backlog[0]
+ count = self._sslobj.write(data)
+ data_len = len(data)
+ if count < data_len:
+ self._write_backlog[0] = data[count:]
+ self._write_buffer_size -= count
else:
- ssldata = self._sslpipe.shutdown(self._finalize)
- offset = 1
-
- for chunk in ssldata:
- self._transport.write(chunk)
-
- if offset < len(data):
- self._write_backlog[0] = (data, offset)
- # A short write means that a write is blocked on a read
- # We need to enable reading if it is paused!
- assert self._sslpipe.need_ssldata
- if self._transport._paused:
- self._transport.resume_reading()
- break
+ del self._write_backlog[0]
+ self._write_buffer_size -= data_len
+ except SSLAgainErrors:
+ pass
+ self._process_outgoing()
+
+ def _process_outgoing(self):
+ if not self._ssl_writing_paused:
+ data = self._outgoing.read()
+ if len(data):
+ self._transport.write(data)
+ self._control_app_writing()
+
+ # Incoming flow
+
+ def _do_read(self):
+ if (
+ self._state not in (
+ SSLProtocolState.WRAPPED,
+ SSLProtocolState.FLUSHING,
+ )
+ ):
+ return
+ try:
+ if not self._app_reading_paused:
+ if self._app_protocol_is_buffer:
+ self._do_read__buffered()
+ else:
+ self._do_read__copied()
+ if self._write_backlog:
+ self._do_write()
+ else:
+ self._process_outgoing()
+ self._control_ssl_reading()
+ except Exception as ex:
+ self._fatal_error(ex, 'Fatal error on SSL protocol')
+
+ def _do_read__buffered(self):
+ offset = 0
+ count = 1
+
+ buf = self._app_protocol_get_buffer(self._get_read_buffer_size())
+ wants = len(buf)
- # An entire chunk from the backlog was processed. We can
- # delete it and reduce the outstanding buffer size.
- del self._write_backlog[0]
- self._write_buffer_size -= len(data)
- except (SystemExit, KeyboardInterrupt):
+ try:
+ count = self._sslobj.read(wants, buf)
+
+ if count > 0:
+ offset = count
+ while offset < wants:
+ count = self._sslobj.read(wants - offset, buf[offset:])
+ if count > 0:
+ offset += count
+ else:
+ break
+ else:
+ self._loop.call_soon(lambda: self._do_read())
+ except SSLAgainErrors:
+ pass
+ if offset > 0:
+ self._app_protocol_buffer_updated(offset)
+ if not count:
+ # close_notify
+ self._call_eof_received()
+ self._start_shutdown()
+
+ def _do_read__copied(self):
+ chunk = b'1'
+ zero = True
+ one = False
+
+ try:
+ while True:
+ chunk = self._sslobj.read(self.max_size)
+ if not chunk:
+ break
+ if zero:
+ zero = False
+ one = True
+ first = chunk
+ elif one:
+ one = False
+ data = [first, chunk]
+ else:
+ data.append(chunk)
+ except SSLAgainErrors:
+ pass
+ if one:
+ self._app_protocol.data_received(first)
+ elif not zero:
+ self._app_protocol.data_received(b''.join(data))
+ if not chunk:
+ # close_notify
+ self._call_eof_received()
+ self._start_shutdown()
+
+ def _call_eof_received(self):
+ try:
+ if self._app_state == AppProtocolState.STATE_CON_MADE:
+ self._app_state = AppProtocolState.STATE_EOF
+ keep_open = self._app_protocol.eof_received()
+ if keep_open:
+ logger.warning('returning true from eof_received() '
+ 'has no effect when using ssl')
+ except (KeyboardInterrupt, SystemExit):
raise
- except BaseException as exc:
- if self._in_handshake:
- # Exceptions will be re-raised in _on_handshake_complete.
- self._on_handshake_complete(exc)
- else:
- self._fatal_error(exc, 'Fatal error on SSL transport')
+ except BaseException as ex:
+ self._fatal_error(ex, 'Error calling eof_received()')
+
+ # Flow control for writes from APP socket
+
+ def _control_app_writing(self):
+ size = self._get_write_buffer_size()
+ if size >= self._outgoing_high_water and not self._app_writing_paused:
+ self._app_writing_paused = True
+ try:
+ self._app_protocol.pause_writing()
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except BaseException as exc:
+ self._loop.call_exception_handler({
+ 'message': 'protocol.pause_writing() failed',
+ 'exception': exc,
+ 'transport': self._app_transport,
+ 'protocol': self,
+ })
+ elif size <= self._outgoing_low_water and self._app_writing_paused:
+ self._app_writing_paused = False
+ try:
+ self._app_protocol.resume_writing()
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except BaseException as exc:
+ self._loop.call_exception_handler({
+ 'message': 'protocol.resume_writing() failed',
+ 'exception': exc,
+ 'transport': self._app_transport,
+ 'protocol': self,
+ })
+
+ def _get_write_buffer_size(self):
+ return self._outgoing.pending + self._write_buffer_size
+
+ def _set_write_buffer_limits(self, high=None, low=None):
+ high, low = add_flowcontrol_defaults(
+ high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE)
+ self._outgoing_high_water = high
+ self._outgoing_low_water = low
+
+ # Flow control for reads to APP socket
+
+ def _pause_reading(self):
+ self._app_reading_paused = True
+
+ def _resume_reading(self):
+ if self._app_reading_paused:
+ self._app_reading_paused = False
+
+ def resume():
+ if self._state == SSLProtocolState.WRAPPED:
+ self._do_read()
+ elif self._state == SSLProtocolState.FLUSHING:
+ self._do_flush()
+ elif self._state == SSLProtocolState.SHUTDOWN:
+ self._do_shutdown()
+ self._loop.call_soon(resume)
+
+ # Flow control for reads from SSL socket
+
+ def _control_ssl_reading(self):
+ size = self._get_read_buffer_size()
+ if size >= self._incoming_high_water and not self._ssl_reading_paused:
+ self._ssl_reading_paused = True
+ self._transport.pause_reading()
+ elif size <= self._incoming_low_water and self._ssl_reading_paused:
+ self._ssl_reading_paused = False
+ self._transport.resume_reading()
+
+ def _set_read_buffer_limits(self, high=None, low=None):
+ high, low = add_flowcontrol_defaults(
+ high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ)
+ self._incoming_high_water = high
+ self._incoming_low_water = low
+
+ def _get_read_buffer_size(self):
+ return self._incoming.pending
+
+ # Flow control for writes to SSL socket
+
+ def pause_writing(self):
+ """Called when the low-level transport's buffer goes over
+ the high-water mark.
+ """
+ assert not self._ssl_writing_paused
+ self._ssl_writing_paused = True
+
+ def resume_writing(self):
+ """Called when the low-level transport's buffer drains below
+ the low-water mark.
+ """
+ assert self._ssl_writing_paused
+ self._ssl_writing_paused = False
+ self._process_outgoing()
def _fatal_error(self, exc, message='Fatal error on transport'):
+ if self._transport:
+ self._transport._force_close(exc)
+
if isinstance(exc, OSError):
if self._loop.get_debug():
logger.debug("%r: %s", self, message, exc_info=True)
- else:
+ elif not isinstance(exc, exceptions.CancelledError):
self._loop.call_exception_handler({
'message': message,
'exception': exc,
'transport': self._transport,
'protocol': self,
})
- if self._transport:
- self._transport._force_close(exc)
-
- def _finalize(self):
- self._sslpipe = None
-
- if self._transport is not None:
- self._transport.close()
-
- def _abort(self):
- try:
- if self._transport is not None:
- self._transport.abort()
- finally:
- self._finalize()
diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py
index c88b818..cf7683f 100644
--- a/Lib/asyncio/unix_events.py
+++ b/Lib/asyncio/unix_events.py
@@ -229,7 +229,8 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
assert server_hostname is None or isinstance(server_hostname, str)
if ssl:
if server_hostname is None:
@@ -241,6 +242,9 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
if ssl_handshake_timeout is not None:
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
+ if ssl_shutdown_timeout is not None:
+ raise ValueError(
+ 'ssl_shutdown_timeout is only meaningful with ssl')
if path is not None:
if sock is not None:
@@ -267,13 +271,15 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout)
return transport, protocol
async def create_unix_server(
self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None,
start_serving=True):
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')
@@ -282,6 +288,10 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
+ if ssl_shutdown_timeout is not None and not ssl:
+ raise ValueError(
+ 'ssl_shutdown_timeout is only meaningful with ssl')
+
if path is not None:
if sock is not None:
raise ValueError(
@@ -328,7 +338,8 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
sock.setblocking(False)
server = base_events.Server(self, [sock], protocol_factory,
- ssl, backlog, ssl_handshake_timeout)
+ ssl, backlog, ssl_handshake_timeout,
+ ssl_shutdown_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
index c64e162..c6671bd 100644
--- a/Lib/test/test_asyncio/test_base_events.py
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -1451,44 +1451,51 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
ANY = mock.ANY
handshake_timeout = object()
+ shutdown_timeout = object()
# First try the default server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
- ssl_handshake_timeout=handshake_timeout)
+ ssl_handshake_timeout=handshake_timeout,
+ ssl_shutdown_timeout=shutdown_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='python.org',
- ssl_handshake_timeout=handshake_timeout)
+ ssl_handshake_timeout=handshake_timeout,
+ ssl_shutdown_timeout=shutdown_timeout)
# Next try an explicit server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='perl.com',
- ssl_handshake_timeout=handshake_timeout)
+ ssl_handshake_timeout=handshake_timeout,
+ ssl_shutdown_timeout=shutdown_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='perl.com',
- ssl_handshake_timeout=handshake_timeout)
+ ssl_handshake_timeout=handshake_timeout,
+ ssl_shutdown_timeout=shutdown_timeout)
# Finally try an explicit empty server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='',
- ssl_handshake_timeout=handshake_timeout)
+ ssl_handshake_timeout=handshake_timeout,
+ ssl_shutdown_timeout=shutdown_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='',
- ssl_handshake_timeout=handshake_timeout)
+ ssl_handshake_timeout=handshake_timeout,
+ ssl_shutdown_timeout=shutdown_timeout)
def test_create_connection_no_ssl_server_hostname_errors(self):
# When not using ssl, server_hostname must be None.
@@ -1869,7 +1876,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
constants.ACCEPT_RETRY_DELAY,
# self.loop._start_serving
mock.ANY,
- MyProto, sock, None, None, mock.ANY, mock.ANY)
+ MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY)
def test_call_coroutine(self):
async def simple_coroutine():
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
index b684fab..9c46018 100644
--- a/Lib/test/test_asyncio/test_selector_events.py
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -71,44 +71,6 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
close_transport(transport)
- @unittest.skipIf(ssl is None, 'No ssl module')
- def test_make_ssl_transport(self):
- m = mock.Mock()
- self.loop._add_reader = mock.Mock()
- self.loop._add_reader._is_coroutine = False
- self.loop._add_writer = mock.Mock()
- self.loop._remove_reader = mock.Mock()
- self.loop._remove_writer = mock.Mock()
- waiter = self.loop.create_future()
- with test_utils.disable_logger():
- transport = self.loop._make_ssl_transport(
- m, asyncio.Protocol(), m, waiter)
-
- with self.assertRaisesRegex(RuntimeError,
- r'SSL transport.*not.*initialized'):
- transport.is_reading()
-
- # execute the handshake while the logger is disabled
- # to ignore SSL handshake failure
- test_utils.run_briefly(self.loop)
-
- self.assertTrue(transport.is_reading())
- transport.pause_reading()
- transport.pause_reading()
- self.assertFalse(transport.is_reading())
- transport.resume_reading()
- transport.resume_reading()
- self.assertTrue(transport.is_reading())
-
- # Sanity check
- class_name = transport.__class__.__name__
- self.assertIn("ssl", class_name.lower())
- self.assertIn("transport", class_name.lower())
-
- transport.close()
- # execute pending callbacks to close the socket transport
- test_utils.run_briefly(self.loop)
-
@mock.patch('asyncio.selector_events.ssl', None)
@mock.patch('asyncio.sslproto.ssl', None)
def test_make_ssl_transport_without_ssl_error(self):
diff --git a/Lib/test/test_asyncio/test_ssl.py b/Lib/test/test_asyncio/test_ssl.py
new file mode 100644
index 0000000..8d1bb03
--- /dev/null
+++ b/Lib/test/test_asyncio/test_ssl.py
@@ -0,0 +1,1721 @@
+import asyncio
+import asyncio.sslproto
+import contextlib
+import gc
+import logging
+import select
+import socket
+import tempfile
+import threading
+import time
+import weakref
+import unittest
+
+try:
+ import ssl
+except ImportError:
+ ssl = None
+
+from test import support
+from test.test_asyncio import utils as test_utils
+
+
+def tearDownModule():
+ asyncio.set_event_loop_policy(None)
+
+
+class MyBaseProto(asyncio.Protocol):
+ connected = None
+ done = None
+
+ def __init__(self, loop=None):
+ self.transport = None
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if loop is not None:
+ self.connected = asyncio.Future(loop=loop)
+ self.done = asyncio.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ if self.connected:
+ self.connected.set_result(None)
+
+ def data_received(self, data):
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+@unittest.skipIf(ssl is None, 'No ssl module')
+class TestSSL(test_utils.TestCase):
+
+ PAYLOAD_SIZE = 1024 * 100
+ TIMEOUT = 60
+
+ def setUp(self):
+ super().setUp()
+ self.loop = asyncio.new_event_loop()
+ self.set_event_loop(self.loop)
+ self.addCleanup(self.loop.close)
+
+ def tearDown(self):
+ # just in case if we have transport close callbacks
+ if not self.loop.is_closed():
+ test_utils.run_briefly(self.loop)
+
+ self.doCleanups()
+ support.gc_collect()
+ super().tearDown()
+
+ def tcp_server(self, server_prog, *,
+ family=socket.AF_INET,
+ addr=None,
+ timeout=5,
+ backlog=1,
+ max_clients=10):
+
+ if addr is None:
+ if family == getattr(socket, "AF_UNIX", None):
+ with tempfile.NamedTemporaryFile() as tmp:
+ addr = tmp.name
+ else:
+ addr = ('127.0.0.1', 0)
+
+ sock = socket.socket(family, socket.SOCK_STREAM)
+
+ if timeout is None:
+ raise RuntimeError('timeout is required')
+ if timeout <= 0:
+ raise RuntimeError('only blocking sockets are supported')
+ sock.settimeout(timeout)
+
+ try:
+ sock.bind(addr)
+ sock.listen(backlog)
+ except OSError as ex:
+ sock.close()
+ raise ex
+
+ return TestThreadedServer(
+ self, sock, server_prog, timeout, max_clients)
+
+ def tcp_client(self, client_prog,
+ family=socket.AF_INET,
+ timeout=10):
+
+ sock = socket.socket(family, socket.SOCK_STREAM)
+
+ if timeout is None:
+ raise RuntimeError('timeout is required')
+ if timeout <= 0:
+ raise RuntimeError('only blocking sockets are supported')
+ sock.settimeout(timeout)
+
+ return TestThreadedClient(
+ self, sock, client_prog, timeout)
+
+ def unix_server(self, *args, **kwargs):
+ return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
+
+ def unix_client(self, *args, **kwargs):
+ return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
+
+ def _create_server_ssl_context(self, certfile, keyfile=None):
+ sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ sslcontext.options |= ssl.OP_NO_SSLv2
+ sslcontext.load_cert_chain(certfile, keyfile)
+ return sslcontext
+
+ def _create_client_ssl_context(self, *, disable_verify=True):
+ sslcontext = ssl.create_default_context()
+ sslcontext.check_hostname = False
+ if disable_verify:
+ sslcontext.verify_mode = ssl.CERT_NONE
+ return sslcontext
+
+ @contextlib.contextmanager
+ def _silence_eof_received_warning(self):
+ # TODO This warning has to be fixed in asyncio.
+ logger = logging.getLogger('asyncio')
+ filter = logging.Filter('has no effect when using ssl')
+ logger.addFilter(filter)
+ try:
+ yield
+ finally:
+ logger.removeFilter(filter)
+
+ def _abort_socket_test(self, ex):
+ try:
+ self.loop.stop()
+ finally:
+ self.fail(ex)
+
+ def new_loop(self):
+ return asyncio.new_event_loop()
+
+ def new_policy(self):
+ return asyncio.DefaultEventLoopPolicy()
+
+ async def wait_closed(self, obj):
+ if not isinstance(obj, asyncio.StreamWriter):
+ return
+ try:
+ await obj.wait_closed()
+ except (BrokenPipeError, ConnectionError):
+ pass
+
+ def test_create_server_ssl_1(self):
+ CNT = 0 # number of clients that were successful
+ TOTAL_CNT = 25 # total number of clients that test will create
+ TIMEOUT = 60.0 # timeout for this test
+
+ A_DATA = b'A' * 1024 * 1024
+ B_DATA = b'B' * 1024 * 1024
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT, test_utils.ONLYKEY
+ )
+ client_sslctx = self._create_client_ssl_context()
+
+ clients = []
+
+ async def handle_client(reader, writer):
+ nonlocal CNT
+
+ data = await reader.readexactly(len(A_DATA))
+ self.assertEqual(data, A_DATA)
+ writer.write(b'OK')
+
+ data = await reader.readexactly(len(B_DATA))
+ self.assertEqual(data, B_DATA)
+ writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
+
+ await writer.drain()
+ writer.close()
+
+ CNT += 1
+
+ async def test_client(addr):
+ fut = asyncio.Future()
+
+ def prog(sock):
+ try:
+ sock.starttls(client_sslctx)
+ sock.connect(addr)
+ sock.send(A_DATA)
+
+ data = sock.recv_all(2)
+ self.assertEqual(data, b'OK')
+
+ sock.send(B_DATA)
+ data = sock.recv_all(4)
+ self.assertEqual(data, b'SPAM')
+
+ sock.close()
+
+ except Exception as ex:
+ self.loop.call_soon_threadsafe(fut.set_exception, ex)
+ else:
+ self.loop.call_soon_threadsafe(fut.set_result, None)
+
+ client = self.tcp_client(prog)
+ client.start()
+ clients.append(client)
+
+ await fut
+
+ async def start_server():
+ extras = {}
+ extras = dict(ssl_handshake_timeout=40.0)
+
+ srv = await asyncio.start_server(
+ handle_client,
+ '127.0.0.1', 0,
+ family=socket.AF_INET,
+ ssl=sslctx,
+ **extras)
+
+ try:
+ srv_socks = srv.sockets
+ self.assertTrue(srv_socks)
+
+ addr = srv_socks[0].getsockname()
+
+ tasks = []
+ for _ in range(TOTAL_CNT):
+ tasks.append(test_client(addr))
+
+ await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
+
+ finally:
+ self.loop.call_soon(srv.close)
+ await srv.wait_closed()
+
+ with self._silence_eof_received_warning():
+ self.loop.run_until_complete(start_server())
+
+ self.assertEqual(CNT, TOTAL_CNT)
+
+ for client in clients:
+ client.stop()
+
+ def test_create_connection_ssl_1(self):
+ self.loop.set_exception_handler(None)
+
+ CNT = 0
+ TOTAL_CNT = 25
+
+ A_DATA = b'A' * 1024 * 1024
+ B_DATA = b'B' * 1024 * 1024
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT,
+ test_utils.ONLYKEY
+ )
+ client_sslctx = self._create_client_ssl_context()
+
+ def server(sock):
+ sock.starttls(
+ sslctx,
+ server_side=True)
+
+ data = sock.recv_all(len(A_DATA))
+ self.assertEqual(data, A_DATA)
+ sock.send(b'OK')
+
+ data = sock.recv_all(len(B_DATA))
+ self.assertEqual(data, B_DATA)
+ sock.send(b'SPAM')
+
+ sock.close()
+
+ async def client(addr):
+ extras = {}
+ extras = dict(ssl_handshake_timeout=40.0)
+
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ **extras)
+
+ writer.write(A_DATA)
+ self.assertEqual(await reader.readexactly(2), b'OK')
+
+ writer.write(B_DATA)
+ self.assertEqual(await reader.readexactly(4), b'SPAM')
+
+ nonlocal CNT
+ CNT += 1
+
+ writer.close()
+ await self.wait_closed(writer)
+
+ async def client_sock(addr):
+ sock = socket.socket()
+ sock.connect(addr)
+ reader, writer = await asyncio.open_connection(
+ sock=sock,
+ ssl=client_sslctx,
+ server_hostname='')
+
+ writer.write(A_DATA)
+ self.assertEqual(await reader.readexactly(2), b'OK')
+
+ writer.write(B_DATA)
+ self.assertEqual(await reader.readexactly(4), b'SPAM')
+
+ nonlocal CNT
+ CNT += 1
+
+ writer.close()
+ await self.wait_closed(writer)
+ sock.close()
+
+ def run(coro):
+ nonlocal CNT
+ CNT = 0
+
+ async def _gather(*tasks):
+ # trampoline
+ return await asyncio.gather(*tasks)
+
+ with self.tcp_server(server,
+ max_clients=TOTAL_CNT,
+ backlog=TOTAL_CNT) as srv:
+ tasks = []
+ for _ in range(TOTAL_CNT):
+ tasks.append(coro(srv.addr))
+
+ self.loop.run_until_complete(_gather(*tasks))
+
+ self.assertEqual(CNT, TOTAL_CNT)
+
+ with self._silence_eof_received_warning():
+ run(client)
+
+ with self._silence_eof_received_warning():
+ run(client_sock)
+
+ def test_create_connection_ssl_slow_handshake(self):
+ client_sslctx = self._create_client_ssl_context()
+
+ # silence error logger
+ self.loop.set_exception_handler(lambda *args: None)
+
+ def server(sock):
+ try:
+ sock.recv_all(1024 * 1024)
+ except ConnectionAbortedError:
+ pass
+ finally:
+ sock.close()
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ ssl_handshake_timeout=1.0)
+ writer.close()
+ await self.wait_closed(writer)
+
+ with self.tcp_server(server,
+ max_clients=1,
+ backlog=1) as srv:
+
+ with self.assertRaisesRegex(
+ ConnectionAbortedError,
+ r'SSL handshake.*is taking longer'):
+
+ self.loop.run_until_complete(client(srv.addr))
+
+ def test_create_connection_ssl_failed_certificate(self):
+ # silence error logger
+ self.loop.set_exception_handler(lambda *args: None)
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT,
+ test_utils.ONLYKEY
+ )
+ client_sslctx = self._create_client_ssl_context(disable_verify=False)
+
+ def server(sock):
+ try:
+ sock.starttls(
+ sslctx,
+ server_side=True)
+ sock.connect()
+ except (ssl.SSLError, OSError):
+ pass
+ finally:
+ sock.close()
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ ssl_handshake_timeout=1.0)
+ writer.close()
+ await self.wait_closed(writer)
+
+ with self.tcp_server(server,
+ max_clients=1,
+ backlog=1) as srv:
+
+ with self.assertRaises(ssl.SSLCertVerificationError):
+ self.loop.run_until_complete(client(srv.addr))
+
+ def test_ssl_handshake_timeout(self):
+ # bpo-29970: Check that a connection is aborted if handshake is not
+ # completed in timeout period, instead of remaining open indefinitely
+ client_sslctx = test_utils.simple_client_sslcontext()
+
+ # silence error logger
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+ server_side_aborted = False
+
+ def server(sock):
+ nonlocal server_side_aborted
+ try:
+ sock.recv_all(1024 * 1024)
+ except ConnectionAbortedError:
+ server_side_aborted = True
+ finally:
+ sock.close()
+
+ async def client(addr):
+ await asyncio.wait_for(
+ self.loop.create_connection(
+ asyncio.Protocol,
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ ssl_handshake_timeout=10.0),
+ 0.5)
+
+ with self.tcp_server(server,
+ max_clients=1,
+ backlog=1) as srv:
+
+ with self.assertRaises(asyncio.TimeoutError):
+ self.loop.run_until_complete(client(srv.addr))
+
+ self.assertTrue(server_side_aborted)
+
+ # Python issue #23197: cancelling a handshake must not raise an
+ # exception or log an error, even if the handshake failed
+ self.assertEqual(messages, [])
+
+ def test_ssl_handshake_connection_lost(self):
+ # #246: make sure that no connection_lost() is called before
+ # connection_made() is called first
+
+ client_sslctx = test_utils.simple_client_sslcontext()
+
+ # silence error logger
+ self.loop.set_exception_handler(lambda loop, ctx: None)
+
+ connection_made_called = False
+ connection_lost_called = False
+
+ def server(sock):
+ sock.recv(1024)
+ # break the connection during handshake
+ sock.close()
+
+ class ClientProto(asyncio.Protocol):
+ def connection_made(self, transport):
+ nonlocal connection_made_called
+ connection_made_called = True
+
+ def connection_lost(self, exc):
+ nonlocal connection_lost_called
+ connection_lost_called = True
+
+ async def client(addr):
+ await self.loop.create_connection(
+ ClientProto,
+ *addr,
+ ssl=client_sslctx,
+ server_hostname=''),
+
+ with self.tcp_server(server,
+ max_clients=1,
+ backlog=1) as srv:
+
+ with self.assertRaises(ConnectionResetError):
+ self.loop.run_until_complete(client(srv.addr))
+
+ if connection_lost_called:
+ if connection_made_called:
+ self.fail("unexpected call to connection_lost()")
+ else:
+ self.fail("unexpected call to connection_lost() without"
+ "calling connection_made()")
+ elif connection_made_called:
+ self.fail("unexpected call to connection_made()")
+
+ def test_ssl_connect_accepted_socket(self):
+ proto = ssl.PROTOCOL_TLS_SERVER
+ server_context = ssl.SSLContext(proto)
+ server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY)
+ if hasattr(server_context, 'check_hostname'):
+ server_context.check_hostname = False
+ server_context.verify_mode = ssl.CERT_NONE
+
+ client_context = ssl.SSLContext(proto)
+ if hasattr(server_context, 'check_hostname'):
+ client_context.check_hostname = False
+ client_context.verify_mode = ssl.CERT_NONE
+
+ def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
+ loop = self.loop
+
+ class MyProto(MyBaseProto):
+
+ def connection_lost(self, exc):
+ super().connection_lost(exc)
+ loop.call_soon(loop.stop)
+
+ def data_received(self, data):
+ super().data_received(data)
+ self.transport.write(expected_response)
+
+ lsock = socket.socket(socket.AF_INET)
+ lsock.bind(('127.0.0.1', 0))
+ lsock.listen(1)
+ addr = lsock.getsockname()
+
+ message = b'test data'
+ response = None
+ expected_response = b'roger'
+
+ def client():
+ nonlocal response
+ try:
+ csock = socket.socket(socket.AF_INET)
+ if client_ssl is not None:
+ csock = client_ssl.wrap_socket(csock)
+ csock.connect(addr)
+ csock.sendall(message)
+ response = csock.recv(99)
+ csock.close()
+ except Exception as exc:
+ print(
+ "Failure in client thread in test_connect_accepted_socket",
+ exc)
+
+ thread = threading.Thread(target=client, daemon=True)
+ thread.start()
+
+ conn, _ = lsock.accept()
+ proto = MyProto(loop=loop)
+ proto.loop = loop
+
+ extras = {}
+ if server_ssl:
+ extras = dict(ssl_handshake_timeout=10.0)
+
+ f = loop.create_task(
+ loop.connect_accepted_socket(
+ (lambda: proto), conn, ssl=server_ssl,
+ **extras))
+ loop.run_forever()
+ conn.close()
+ lsock.close()
+
+ thread.join(1)
+ self.assertFalse(thread.is_alive())
+ self.assertEqual(proto.state, 'CLOSED')
+ self.assertEqual(proto.nbytes, len(message))
+ self.assertEqual(response, expected_response)
+ tr, _ = f.result()
+
+ if server_ssl:
+ self.assertIn('SSL', tr.__class__.__name__)
+
+ tr.close()
+ # let it close
+ self.loop.run_until_complete(asyncio.sleep(0.1))
+
+ def test_start_tls_client_corrupted_ssl(self):
+ self.loop.set_exception_handler(lambda loop, ctx: None)
+
+ sslctx = test_utils.simple_server_sslcontext()
+ client_sslctx = test_utils.simple_client_sslcontext()
+
+ def server(sock):
+ orig_sock = sock.dup()
+ try:
+ sock.starttls(
+ sslctx,
+ server_side=True)
+ sock.sendall(b'A\n')
+ sock.recv_all(1)
+ orig_sock.send(b'please corrupt the SSL connection')
+ except ssl.SSLError:
+ pass
+ finally:
+ sock.close()
+ orig_sock.close()
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='')
+
+ self.assertEqual(await reader.readline(), b'A\n')
+ writer.write(b'B')
+ with self.assertRaises(ssl.SSLError):
+ await reader.readline()
+ writer.close()
+ try:
+ await self.wait_closed(writer)
+ except ssl.SSLError:
+ pass
+ return 'OK'
+
+ with self.tcp_server(server,
+ max_clients=1,
+ backlog=1) as srv:
+
+ res = self.loop.run_until_complete(client(srv.addr))
+
+ self.assertEqual(res, 'OK')
+
+ def test_start_tls_client_reg_proto_1(self):
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+ server_context = test_utils.simple_server_sslcontext()
+ client_context = test_utils.simple_client_sslcontext()
+
+ def serve(sock):
+ sock.settimeout(self.TIMEOUT)
+
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.starttls(server_context, server_side=True)
+
+ sock.sendall(b'O')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.unwrap()
+ sock.close()
+
+ class ClientProto(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(proto, tr):
+ proto.con_made_cnt += 1
+ # Ensure connection_made gets called only once.
+ self.assertEqual(proto.con_made_cnt, 1)
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
+ def eof_received(self):
+ self.on_eof.set_result(True)
+
+ async def client(addr):
+ await asyncio.sleep(0.5)
+
+ on_data = self.loop.create_future()
+ on_eof = self.loop.create_future()
+
+ tr, proto = await self.loop.create_connection(
+ lambda: ClientProto(on_data, on_eof), *addr)
+
+ tr.write(HELLO_MSG)
+ new_tr = await self.loop.start_tls(tr, proto, client_context)
+
+ self.assertEqual(await on_data, b'O')
+ new_tr.write(HELLO_MSG)
+ await on_eof
+
+ new_tr.close()
+
+ with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
+ self.loop.run_until_complete(
+ asyncio.wait_for(client(srv.addr), timeout=10))
+
+ def test_create_connection_memory_leak(self):
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+ server_context = self._create_server_ssl_context(
+ test_utils.ONLYCERT, test_utils.ONLYKEY)
+ client_context = self._create_client_ssl_context()
+
+ def serve(sock):
+ sock.settimeout(self.TIMEOUT)
+
+ sock.starttls(server_context, server_side=True)
+
+ sock.sendall(b'O')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.unwrap()
+ sock.close()
+
+ class ClientProto(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(proto, tr):
+ # XXX: We assume user stores the transport in protocol
+ proto.tr = tr
+ proto.con_made_cnt += 1
+ # Ensure connection_made gets called only once.
+ self.assertEqual(proto.con_made_cnt, 1)
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
+ def eof_received(self):
+ self.on_eof.set_result(True)
+
+ async def client(addr):
+ await asyncio.sleep(0.5)
+
+ on_data = self.loop.create_future()
+ on_eof = self.loop.create_future()
+
+ tr, proto = await self.loop.create_connection(
+ lambda: ClientProto(on_data, on_eof), *addr,
+ ssl=client_context)
+
+ self.assertEqual(await on_data, b'O')
+ tr.write(HELLO_MSG)
+ await on_eof
+
+ tr.close()
+
+ with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
+ self.loop.run_until_complete(
+ asyncio.wait_for(client(srv.addr), timeout=10))
+
+ # No garbage is left for SSL client from loop.create_connection, even
+ # if user stores the SSLTransport in corresponding protocol instance
+ client_context = weakref.ref(client_context)
+ self.assertIsNone(client_context())
+
+ def test_start_tls_client_buf_proto_1(self):
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+ server_context = test_utils.simple_server_sslcontext()
+ client_context = test_utils.simple_client_sslcontext()
+
+ client_con_made_calls = 0
+
+ def serve(sock):
+ sock.settimeout(self.TIMEOUT)
+
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.starttls(server_context, server_side=True)
+
+ sock.sendall(b'O')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.sendall(b'2')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.unwrap()
+ sock.close()
+
+ class ClientProtoFirst(asyncio.BufferedProtocol):
+ def __init__(self, on_data):
+ self.on_data = on_data
+ self.buf = bytearray(1)
+
+ def connection_made(self, tr):
+ nonlocal client_con_made_calls
+ client_con_made_calls += 1
+
+ def get_buffer(self, sizehint):
+ return self.buf
+
+ def buffer_updated(self, nsize):
+ assert nsize == 1
+ self.on_data.set_result(bytes(self.buf[:nsize]))
+
+ def eof_received(self):
+ pass
+
+ class ClientProtoSecond(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(self, tr):
+ nonlocal client_con_made_calls
+ client_con_made_calls += 1
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
+ def eof_received(self):
+ self.on_eof.set_result(True)
+
+ async def client(addr):
+ await asyncio.sleep(0.5)
+
+ on_data1 = self.loop.create_future()
+ on_data2 = self.loop.create_future()
+ on_eof = self.loop.create_future()
+
+ tr, proto = await self.loop.create_connection(
+ lambda: ClientProtoFirst(on_data1), *addr)
+
+ tr.write(HELLO_MSG)
+ new_tr = await self.loop.start_tls(tr, proto, client_context)
+
+ self.assertEqual(await on_data1, b'O')
+ new_tr.write(HELLO_MSG)
+
+ new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
+ self.assertEqual(await on_data2, b'2')
+ new_tr.write(HELLO_MSG)
+ await on_eof
+
+ new_tr.close()
+
+ # connection_made() should be called only once -- when
+ # we establish connection for the first time. Start TLS
+ # doesn't call connection_made() on application protocols.
+ self.assertEqual(client_con_made_calls, 1)
+
+ with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
+ self.loop.run_until_complete(
+ asyncio.wait_for(client(srv.addr),
+ timeout=self.TIMEOUT))
+
+ def test_start_tls_slow_client_cancel(self):
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+ client_context = test_utils.simple_client_sslcontext()
+ server_waits_on_handshake = self.loop.create_future()
+
+ def serve(sock):
+ sock.settimeout(self.TIMEOUT)
+
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ try:
+ self.loop.call_soon_threadsafe(
+ server_waits_on_handshake.set_result, None)
+ data = sock.recv_all(1024 * 1024)
+ except ConnectionAbortedError:
+ pass
+ finally:
+ sock.close()
+
+ class ClientProto(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(proto, tr):
+ proto.con_made_cnt += 1
+ # Ensure connection_made gets called only once.
+ self.assertEqual(proto.con_made_cnt, 1)
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
+ def eof_received(self):
+ self.on_eof.set_result(True)
+
+ async def client(addr):
+ await asyncio.sleep(0.5)
+
+ on_data = self.loop.create_future()
+ on_eof = self.loop.create_future()
+
+ tr, proto = await self.loop.create_connection(
+ lambda: ClientProto(on_data, on_eof), *addr)
+
+ tr.write(HELLO_MSG)
+
+ await server_waits_on_handshake
+
+ with self.assertRaises(asyncio.TimeoutError):
+ await asyncio.wait_for(
+ self.loop.start_tls(tr, proto, client_context),
+ 0.5)
+
+ with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
+ self.loop.run_until_complete(
+ asyncio.wait_for(client(srv.addr), timeout=10))
+
+ def test_start_tls_server_1(self):
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+ server_context = test_utils.simple_server_sslcontext()
+ client_context = test_utils.simple_client_sslcontext()
+
+ def client(sock, addr):
+ sock.settimeout(self.TIMEOUT)
+
+ sock.connect(addr)
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.starttls(client_context)
+ sock.sendall(HELLO_MSG)
+
+ sock.unwrap()
+ sock.close()
+
+ class ServerProto(asyncio.Protocol):
+ def __init__(self, on_con, on_eof, on_con_lost):
+ self.on_con = on_con
+ self.on_eof = on_eof
+ self.on_con_lost = on_con_lost
+ self.data = b''
+
+ def connection_made(self, tr):
+ self.on_con.set_result(tr)
+
+ def data_received(self, data):
+ self.data += data
+
+ def eof_received(self):
+ self.on_eof.set_result(1)
+
+ def connection_lost(self, exc):
+ if exc is None:
+ self.on_con_lost.set_result(None)
+ else:
+ self.on_con_lost.set_exception(exc)
+
+ async def main(proto, on_con, on_eof, on_con_lost):
+ tr = await on_con
+ tr.write(HELLO_MSG)
+
+ self.assertEqual(proto.data, b'')
+
+ new_tr = await self.loop.start_tls(
+ tr, proto, server_context,
+ server_side=True,
+ ssl_handshake_timeout=self.TIMEOUT)
+
+ await on_eof
+ await on_con_lost
+ self.assertEqual(proto.data, HELLO_MSG)
+ new_tr.close()
+
+ async def run_main():
+ on_con = self.loop.create_future()
+ on_eof = self.loop.create_future()
+ on_con_lost = self.loop.create_future()
+ proto = ServerProto(on_con, on_eof, on_con_lost)
+
+ server = await self.loop.create_server(
+ lambda: proto, '127.0.0.1', 0)
+ addr = server.sockets[0].getsockname()
+
+ with self.tcp_client(lambda sock: client(sock, addr),
+ timeout=self.TIMEOUT):
+ await asyncio.wait_for(
+ main(proto, on_con, on_eof, on_con_lost),
+ timeout=self.TIMEOUT)
+
+ server.close()
+ await server.wait_closed()
+
+ self.loop.run_until_complete(run_main())
+
+ def test_create_server_ssl_over_ssl(self):
+ CNT = 0 # number of clients that were successful
+ TOTAL_CNT = 25 # total number of clients that test will create
+ TIMEOUT = 10.0 # timeout for this test
+
+ A_DATA = b'A' * 1024 * 1024
+ B_DATA = b'B' * 1024 * 1024
+
+ sslctx_1 = self._create_server_ssl_context(
+ test_utils.ONLYCERT, test_utils.ONLYKEY)
+ client_sslctx_1 = self._create_client_ssl_context()
+ sslctx_2 = self._create_server_ssl_context(
+ test_utils.ONLYCERT, test_utils.ONLYKEY)
+ client_sslctx_2 = self._create_client_ssl_context()
+
+ clients = []
+
+ async def handle_client(reader, writer):
+ nonlocal CNT
+
+ data = await reader.readexactly(len(A_DATA))
+ self.assertEqual(data, A_DATA)
+ writer.write(b'OK')
+
+ data = await reader.readexactly(len(B_DATA))
+ self.assertEqual(data, B_DATA)
+ writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
+
+ await writer.drain()
+ writer.close()
+
+ CNT += 1
+
+ class ServerProtocol(asyncio.StreamReaderProtocol):
+ def connection_made(self, transport):
+ super_ = super()
+ transport.pause_reading()
+ fut = self._loop.create_task(self._loop.start_tls(
+ transport, self, sslctx_2, server_side=True))
+
+ def cb(_):
+ try:
+ tr = fut.result()
+ except Exception as ex:
+ super_.connection_lost(ex)
+ else:
+ super_.connection_made(tr)
+ fut.add_done_callback(cb)
+
+ def server_protocol_factory():
+ reader = asyncio.StreamReader()
+ protocol = ServerProtocol(reader, handle_client)
+ return protocol
+
+ async def test_client(addr):
+ fut = asyncio.Future()
+
+ def prog(sock):
+ try:
+ sock.connect(addr)
+ sock.starttls(client_sslctx_1)
+
+ # because wrap_socket() doesn't work correctly on
+ # SSLSocket, we have to do the 2nd level SSL manually
+ incoming = ssl.MemoryBIO()
+ outgoing = ssl.MemoryBIO()
+ sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)
+
+ def do(func, *args):
+ while True:
+ try:
+ rv = func(*args)
+ break
+ except ssl.SSLWantReadError:
+ if outgoing.pending:
+ sock.send(outgoing.read())
+ incoming.write(sock.recv(65536))
+ if outgoing.pending:
+ sock.send(outgoing.read())
+ return rv
+
+ do(sslobj.do_handshake)
+
+ do(sslobj.write, A_DATA)
+ data = do(sslobj.read, 2)
+ self.assertEqual(data, b'OK')
+
+ do(sslobj.write, B_DATA)
+ data = b''
+ while True:
+ chunk = do(sslobj.read, 4)
+ if not chunk:
+ break
+ data += chunk
+ self.assertEqual(data, b'SPAM')
+
+ do(sslobj.unwrap)
+ sock.close()
+
+ except Exception as ex:
+ self.loop.call_soon_threadsafe(fut.set_exception, ex)
+ sock.close()
+ else:
+ self.loop.call_soon_threadsafe(fut.set_result, None)
+
+ client = self.tcp_client(prog)
+ client.start()
+ clients.append(client)
+
+ await fut
+
+ async def start_server():
+ extras = {}
+
+ srv = await self.loop.create_server(
+ server_protocol_factory,
+ '127.0.0.1', 0,
+ family=socket.AF_INET,
+ ssl=sslctx_1,
+ **extras)
+
+ try:
+ srv_socks = srv.sockets
+ self.assertTrue(srv_socks)
+
+ addr = srv_socks[0].getsockname()
+
+ tasks = []
+ for _ in range(TOTAL_CNT):
+ tasks.append(test_client(addr))
+
+ await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
+
+ finally:
+ self.loop.call_soon(srv.close)
+ await srv.wait_closed()
+
+ with self._silence_eof_received_warning():
+ self.loop.run_until_complete(start_server())
+
+ self.assertEqual(CNT, TOTAL_CNT)
+
+ for client in clients:
+ client.stop()
+
+ def test_shutdown_cleanly(self):
+ CNT = 0
+ TOTAL_CNT = 25
+
+ A_DATA = b'A' * 1024 * 1024
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT, test_utils.ONLYKEY)
+ client_sslctx = self._create_client_ssl_context()
+
+ def server(sock):
+ sock.starttls(
+ sslctx,
+ server_side=True)
+
+ data = sock.recv_all(len(A_DATA))
+ self.assertEqual(data, A_DATA)
+ sock.send(b'OK')
+
+ sock.unwrap()
+
+ sock.close()
+
+ async def client(addr):
+ extras = {}
+ extras = dict(ssl_handshake_timeout=10.0)
+
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ **extras)
+
+ writer.write(A_DATA)
+ self.assertEqual(await reader.readexactly(2), b'OK')
+
+ self.assertEqual(await reader.read(), b'')
+
+ nonlocal CNT
+ CNT += 1
+
+ writer.close()
+ await self.wait_closed(writer)
+
+ def run(coro):
+ nonlocal CNT
+ CNT = 0
+
+ async def _gather(*tasks):
+ return await asyncio.gather(*tasks)
+
+ with self.tcp_server(server,
+ max_clients=TOTAL_CNT,
+ backlog=TOTAL_CNT) as srv:
+ tasks = []
+ for _ in range(TOTAL_CNT):
+ tasks.append(coro(srv.addr))
+
+ self.loop.run_until_complete(
+ _gather(*tasks))
+
+ self.assertEqual(CNT, TOTAL_CNT)
+
+ with self._silence_eof_received_warning():
+ run(client)
+
+ def test_flush_before_shutdown(self):
+ CHUNK = 1024 * 128
+ SIZE = 32
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT, test_utils.ONLYKEY)
+ client_sslctx = self._create_client_ssl_context()
+
+ future = None
+
+ def server(sock):
+ sock.starttls(sslctx, server_side=True)
+ self.assertEqual(sock.recv_all(4), b'ping')
+ sock.send(b'pong')
+ time.sleep(0.5) # hopefully stuck the TCP buffer
+ data = sock.recv_all(CHUNK * SIZE)
+ self.assertEqual(len(data), CHUNK * SIZE)
+ sock.close()
+
+ def run(meth):
+ def wrapper(sock):
+ try:
+ meth(sock)
+ except Exception as ex:
+ self.loop.call_soon_threadsafe(future.set_exception, ex)
+ else:
+ self.loop.call_soon_threadsafe(future.set_result, None)
+ return wrapper
+
+ async def client(addr):
+ nonlocal future
+ future = self.loop.create_future()
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='')
+ sslprotocol = writer.transport._ssl_protocol
+ writer.write(b'ping')
+ data = await reader.readexactly(4)
+ self.assertEqual(data, b'pong')
+
+ sslprotocol.pause_writing()
+ for _ in range(SIZE):
+ writer.write(b'x' * CHUNK)
+
+ writer.close()
+ sslprotocol.resume_writing()
+
+ await self.wait_closed(writer)
+ try:
+ data = await reader.read()
+ self.assertEqual(data, b'')
+ except ConnectionResetError:
+ pass
+ await future
+
+ with self.tcp_server(run(server)) as srv:
+ self.loop.run_until_complete(client(srv.addr))
+
+ def test_remote_shutdown_receives_trailing_data(self):
+ CHUNK = 1024 * 128
+ SIZE = 32
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT,
+ test_utils.ONLYKEY
+ )
+ client_sslctx = self._create_client_ssl_context()
+ future = None
+
+ def server(sock):
+ incoming = ssl.MemoryBIO()
+ outgoing = ssl.MemoryBIO()
+ sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)
+
+ while True:
+ try:
+ sslobj.do_handshake()
+ except ssl.SSLWantReadError:
+ if outgoing.pending:
+ sock.send(outgoing.read())
+ incoming.write(sock.recv(16384))
+ else:
+ if outgoing.pending:
+ sock.send(outgoing.read())
+ break
+
+ while True:
+ try:
+ data = sslobj.read(4)
+ except ssl.SSLWantReadError:
+ incoming.write(sock.recv(16384))
+ else:
+ break
+
+ self.assertEqual(data, b'ping')
+ sslobj.write(b'pong')
+ sock.send(outgoing.read())
+
+ time.sleep(0.2) # wait for the peer to fill its backlog
+
+ # send close_notify but don't wait for response
+ with self.assertRaises(ssl.SSLWantReadError):
+ sslobj.unwrap()
+ sock.send(outgoing.read())
+
+ # should receive all data
+ data_len = 0
+ while True:
+ try:
+ chunk = len(sslobj.read(16384))
+ data_len += chunk
+ except ssl.SSLWantReadError:
+ incoming.write(sock.recv(16384))
+ except ssl.SSLZeroReturnError:
+ break
+
+ self.assertEqual(data_len, CHUNK * SIZE)
+
+ # verify that close_notify is received
+ sslobj.unwrap()
+
+ sock.close()
+
+ def eof_server(sock):
+ sock.starttls(sslctx, server_side=True)
+ self.assertEqual(sock.recv_all(4), b'ping')
+ sock.send(b'pong')
+
+ time.sleep(0.2) # wait for the peer to fill its backlog
+
+ # send EOF
+ sock.shutdown(socket.SHUT_WR)
+
+ # should receive all data
+ data = sock.recv_all(CHUNK * SIZE)
+ self.assertEqual(len(data), CHUNK * SIZE)
+
+ sock.close()
+
+ async def client(addr):
+ nonlocal future
+ future = self.loop.create_future()
+
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='')
+ writer.write(b'ping')
+ data = await reader.readexactly(4)
+ self.assertEqual(data, b'pong')
+
+ # fill write backlog in a hacky way - renegotiation won't help
+ for _ in range(SIZE):
+ writer.transport._test__append_write_backlog(b'x' * CHUNK)
+
+ try:
+ data = await reader.read()
+ self.assertEqual(data, b'')
+ except (BrokenPipeError, ConnectionResetError):
+ pass
+
+ await future
+
+ writer.close()
+ await self.wait_closed(writer)
+
+ def run(meth):
+ def wrapper(sock):
+ try:
+ meth(sock)
+ except Exception as ex:
+ self.loop.call_soon_threadsafe(future.set_exception, ex)
+ else:
+ self.loop.call_soon_threadsafe(future.set_result, None)
+ return wrapper
+
+ with self.tcp_server(run(server)) as srv:
+ self.loop.run_until_complete(client(srv.addr))
+
+ with self.tcp_server(run(eof_server)) as srv:
+ self.loop.run_until_complete(client(srv.addr))
+
+ def test_connect_timeout_warning(self):
+ s = socket.socket(socket.AF_INET)
+ s.bind(('127.0.0.1', 0))
+ addr = s.getsockname()
+
+ async def test():
+ try:
+ await asyncio.wait_for(
+ self.loop.create_connection(asyncio.Protocol,
+ *addr, ssl=True),
+ 0.1)
+ except (ConnectionRefusedError, asyncio.TimeoutError):
+ pass
+ else:
+ self.fail('TimeoutError is not raised')
+
+ with s:
+ try:
+ with self.assertWarns(ResourceWarning) as cm:
+ self.loop.run_until_complete(test())
+ gc.collect()
+ gc.collect()
+ gc.collect()
+ except AssertionError as e:
+ self.assertEqual(str(e), 'ResourceWarning not triggered')
+ else:
+ self.fail('Unexpected ResourceWarning: {}'.format(cm.warning))
+
+ def test_handshake_timeout_handler_leak(self):
+ s = socket.socket(socket.AF_INET)
+ s.bind(('127.0.0.1', 0))
+ s.listen(1)
+ addr = s.getsockname()
+
+ async def test(ctx):
+ try:
+ await asyncio.wait_for(
+ self.loop.create_connection(asyncio.Protocol, *addr,
+ ssl=ctx),
+ 0.1)
+ except (ConnectionRefusedError, asyncio.TimeoutError):
+ pass
+ else:
+ self.fail('TimeoutError is not raised')
+
+ with s:
+ ctx = ssl.create_default_context()
+ self.loop.run_until_complete(test(ctx))
+ ctx = weakref.ref(ctx)
+
+ # SSLProtocol should be DECREF to 0
+ self.assertIsNone(ctx())
+
+ def test_shutdown_timeout_handler_leak(self):
+ loop = self.loop
+
+ def server(sock):
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT,
+ test_utils.ONLYKEY
+ )
+ sock = sslctx.wrap_socket(sock, server_side=True)
+ sock.recv(32)
+ sock.close()
+
+ class Protocol(asyncio.Protocol):
+ def __init__(self):
+ self.fut = asyncio.Future(loop=loop)
+
+ def connection_lost(self, exc):
+ self.fut.set_result(None)
+
+ async def client(addr, ctx):
+ tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
+ tr.close()
+ await pr.fut
+
+ with self.tcp_server(server) as srv:
+ ctx = self._create_client_ssl_context()
+ loop.run_until_complete(client(srv.addr, ctx))
+ ctx = weakref.ref(ctx)
+
+ # asyncio has no shutdown timeout, but it ends up with a circular
+ # reference loop - not ideal (introduces gc glitches), but at least
+ # not leaking
+ gc.collect()
+ gc.collect()
+ gc.collect()
+
+ # SSLProtocol should be DECREF to 0
+ self.assertIsNone(ctx())
+
+ def test_shutdown_timeout_handler_not_set(self):
+ loop = self.loop
+ eof = asyncio.Event()
+ extra = None
+
+ def server(sock):
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT,
+ test_utils.ONLYKEY
+ )
+ sock = sslctx.wrap_socket(sock, server_side=True)
+ sock.send(b'hello')
+ assert sock.recv(1024) == b'world'
+ sock.send(b'extra bytes')
+ # sending EOF here
+ sock.shutdown(socket.SHUT_WR)
+ loop.call_soon_threadsafe(eof.set)
+ # make sure we have enough time to reproduce the issue
+ assert sock.recv(1024) == b''
+ sock.close()
+
+ class Protocol(asyncio.Protocol):
+ def __init__(self):
+ self.fut = asyncio.Future(loop=loop)
+ self.transport = None
+
+ def connection_made(self, transport):
+ self.transport = transport
+
+ def data_received(self, data):
+ if data == b'hello':
+ self.transport.write(b'world')
+ # pause reading would make incoming data stay in the sslobj
+ self.transport.pause_reading()
+ else:
+ nonlocal extra
+ extra = data
+
+ def connection_lost(self, exc):
+ if exc is None:
+ self.fut.set_result(None)
+ else:
+ self.fut.set_exception(exc)
+
+ async def client(addr):
+ ctx = self._create_client_ssl_context()
+ tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
+ await eof.wait()
+ tr.resume_reading()
+ await pr.fut
+ tr.close()
+ assert extra == b'extra bytes'
+
+ with self.tcp_server(server) as srv:
+ loop.run_until_complete(client(srv.addr))
+
+
+###############################################################################
+# Socket Testing Utilities
+###############################################################################
+
+
+class TestSocketWrapper:
+
+ def __init__(self, sock):
+ self.__sock = sock
+
+ def recv_all(self, n):
+ buf = b''
+ while len(buf) < n:
+ data = self.recv(n - len(buf))
+ if data == b'':
+ raise ConnectionAbortedError
+ buf += data
+ return buf
+
+ def starttls(self, ssl_context, *,
+ server_side=False,
+ server_hostname=None,
+ do_handshake_on_connect=True):
+
+ assert isinstance(ssl_context, ssl.SSLContext)
+
+ ssl_sock = ssl_context.wrap_socket(
+ self.__sock, server_side=server_side,
+ server_hostname=server_hostname,
+ do_handshake_on_connect=do_handshake_on_connect)
+
+ if server_side:
+ ssl_sock.do_handshake()
+
+ self.__sock.close()
+ self.__sock = ssl_sock
+
+ def __getattr__(self, name):
+ return getattr(self.__sock, name)
+
+ def __repr__(self):
+ return '<{} {!r}>'.format(type(self).__name__, self.__sock)
+
+
+class SocketThread(threading.Thread):
+
+ def stop(self):
+ self._active = False
+ self.join()
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, *exc):
+ self.stop()
+
+
+class TestThreadedClient(SocketThread):
+
+ def __init__(self, test, sock, prog, timeout):
+ threading.Thread.__init__(self, None, None, 'test-client')
+ self.daemon = True
+
+ self._timeout = timeout
+ self._sock = sock
+ self._active = True
+ self._prog = prog
+ self._test = test
+
+ def run(self):
+ try:
+ self._prog(TestSocketWrapper(self._sock))
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except BaseException as ex:
+ self._test._abort_socket_test(ex)
+
+
+class TestThreadedServer(SocketThread):
+
+ def __init__(self, test, sock, prog, timeout, max_clients):
+ threading.Thread.__init__(self, None, None, 'test-server')
+ self.daemon = True
+
+ self._clients = 0
+ self._finished_clients = 0
+ self._max_clients = max_clients
+ self._timeout = timeout
+ self._sock = sock
+ self._active = True
+
+ self._prog = prog
+
+ self._s1, self._s2 = socket.socketpair()
+ self._s1.setblocking(False)
+
+ self._test = test
+
+ def stop(self):
+ try:
+ if self._s2 and self._s2.fileno() != -1:
+ try:
+ self._s2.send(b'stop')
+ except OSError:
+ pass
+ finally:
+ super().stop()
+
+ def run(self):
+ try:
+ with self._sock:
+ self._sock.setblocking(0)
+ self._run()
+ finally:
+ self._s1.close()
+ self._s2.close()
+
+ def _run(self):
+ while self._active:
+ if self._clients >= self._max_clients:
+ return
+
+ r, w, x = select.select(
+ [self._sock, self._s1], [], [], self._timeout)
+
+ if self._s1 in r:
+ return
+
+ if self._sock in r:
+ try:
+ conn, addr = self._sock.accept()
+ except BlockingIOError:
+ continue
+ except socket.timeout:
+ if not self._active:
+ return
+ else:
+ raise
+ else:
+ self._clients += 1
+ conn.settimeout(self._timeout)
+ try:
+ with conn:
+ self._handle_client(conn)
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except BaseException as ex:
+ self._active = False
+ try:
+ raise
+ finally:
+ self._test._abort_socket_test(ex)
+
+ def _handle_client(self, sock):
+ self._prog(TestSocketWrapper(sock))
+
+ @property
+ def addr(self):
+ return self._sock.getsockname()
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py
index 22a216a..6e8de7c 100644
--- a/Lib/test/test_asyncio/test_sslproto.py
+++ b/Lib/test/test_asyncio/test_sslproto.py
@@ -15,7 +15,6 @@ import asyncio
from asyncio import log
from asyncio import protocols
from asyncio import sslproto
-from test import support
from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests
@@ -44,16 +43,13 @@ class SslProtoHandshakeTests(test_utils.TestCase):
def connection_made(self, ssl_proto, *, do_handshake=None):
transport = mock.Mock()
- sslpipe = mock.Mock()
- sslpipe.shutdown.return_value = b''
- if do_handshake:
- sslpipe.do_handshake.side_effect = do_handshake
- else:
- def mock_handshake(callback):
- return []
- sslpipe.do_handshake.side_effect = mock_handshake
- with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
- ssl_proto.connection_made(transport)
+ sslobj = mock.Mock()
+ # emulate reading decompressed data
+ sslobj.read.side_effect = ssl.SSLWantReadError
+ if do_handshake is not None:
+ sslobj.do_handshake = do_handshake
+ ssl_proto._sslobj = sslobj
+ ssl_proto.connection_made(transport)
return transport
def test_handshake_timeout_zero(self):
@@ -75,7 +71,10 @@ class SslProtoHandshakeTests(test_utils.TestCase):
def test_eof_received_waiter(self):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
- self.connection_made(ssl_proto)
+ self.connection_made(
+ ssl_proto,
+ do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
+ )
ssl_proto.eof_received()
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionResetError)
@@ -100,7 +99,10 @@ class SslProtoHandshakeTests(test_utils.TestCase):
# yield from waiter hang if lost_connection was called.
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
- self.connection_made(ssl_proto)
+ self.connection_made(
+ ssl_proto,
+ do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
+ )
ssl_proto.connection_lost(ConnectionAbortedError)
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
@@ -110,7 +112,10 @@ class SslProtoHandshakeTests(test_utils.TestCase):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
- transport = self.connection_made(ssl_proto)
+ transport = self.connection_made(
+ ssl_proto,
+ do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
+ )
test_utils.run_briefly(self.loop)
ssl_proto._app_transport.close()
@@ -143,7 +148,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
transp.close()
# should not raise
- self.assertIsNone(ssl_proto.data_received(b'data'))
+ self.assertIsNone(ssl_proto.buffer_updated(5))
def test_write_after_closing(self):
ssl_proto = self.ssl_protocol()
diff --git a/Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst b/Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst
new file mode 100644
index 0000000..1a48aa5
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst
@@ -0,0 +1,2 @@
+Reimplement SSL/TLS support in asyncio, borrow the implementation from
+uvloop library.