summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPablo Galindo <Pablogsal@gmail.com>2021-05-03 15:21:59 (GMT)
committerGitHub <noreply@github.com>2021-05-03 15:21:59 (GMT)
commit7719953b30430b351ba0f153c2b51b16cc68ee36 (patch)
tree8014086b85a13ed79d45e29ab74a9a9f5c9c68eb
parent39494285e15dc2d291ec13de5045b930eaf0a3db (diff)
downloadcpython-7719953b30430b351ba0f153c2b51b16cc68ee36.zip
cpython-7719953b30430b351ba0f153c2b51b16cc68ee36.tar.gz
cpython-7719953b30430b351ba0f153c2b51b16cc68ee36.tar.bz2
bpo-44011: Revert "New asyncio ssl implementation (GH-17975)" (GH-25848)
This reverts commit 5fb06edbbb769561e245d0fe13002bab50e2ae60 and all subsequent dependent commits.
-rw-r--r--Lib/asyncio/base_events.py43
-rw-r--r--Lib/asyncio/constants.py7
-rw-r--r--Lib/asyncio/events.py19
-rw-r--r--Lib/asyncio/proactor_events.py12
-rw-r--r--Lib/asyncio/selector_events.py31
-rw-r--r--Lib/asyncio/sslproto.py1059
-rw-r--r--Lib/asyncio/unix_events.py17
-rw-r--r--Lib/test/test_asyncio/test_base_events.py21
-rw-r--r--Lib/test/test_asyncio/test_selector_events.py38
-rw-r--r--Lib/test/test_asyncio/test_ssl.py1723
-rw-r--r--Lib/test/test_asyncio/test_sslproto.py35
-rw-r--r--Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst2
12 files changed, 527 insertions, 2480 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index e54ee30..f789635 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -273,7 +273,7 @@ class _SendfileFallbackProtocol(protocols.Protocol):
class Server(events.AbstractServer):
def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
- ssl_handshake_timeout, ssl_shutdown_timeout=None):
+ ssl_handshake_timeout):
self._loop = loop
self._sockets = sockets
self._active_count = 0
@@ -282,7 +282,6 @@ class Server(events.AbstractServer):
self._backlog = backlog
self._ssl_context = ssl_context
self._ssl_handshake_timeout = ssl_handshake_timeout
- self._ssl_shutdown_timeout = ssl_shutdown_timeout
self._serving = False
self._serving_forever_fut = None
@@ -314,8 +313,7 @@ class Server(events.AbstractServer):
sock.listen(self._backlog)
self._loop._start_serving(
self._protocol_factory, sock, self._ssl_context,
- self, self._backlog, self._ssl_handshake_timeout,
- self._ssl_shutdown_timeout)
+ self, self._backlog, self._ssl_handshake_timeout)
def get_loop(self):
return self._loop
@@ -469,7 +467,6 @@ class BaseEventLoop(events.AbstractEventLoop):
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None,
call_connection_made=True):
"""Create SSL transport."""
raise NotImplementedError
@@ -972,7 +969,6 @@ class BaseEventLoop(events.AbstractEventLoop):
proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None,
ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
"""Connect to a TCP server.
@@ -1008,10 +1004,6 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
- if ssl_shutdown_timeout is not None and not ssl:
- raise ValueError(
- 'ssl_shutdown_timeout is only meaningful with ssl')
-
if happy_eyeballs_delay is not None and interleave is None:
# If using happy eyeballs, default to interleave addresses by family
interleave = 1
@@ -1087,8 +1079,7 @@ class BaseEventLoop(events.AbstractEventLoop):
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout,
- ssl_shutdown_timeout=ssl_shutdown_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
@@ -1100,8 +1091,7 @@ class BaseEventLoop(events.AbstractEventLoop):
async def _create_connection_transport(
self, sock, protocol_factory, ssl,
server_hostname, server_side=False,
- ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_handshake_timeout=None):
sock.setblocking(False)
@@ -1112,8 +1102,7 @@ class BaseEventLoop(events.AbstractEventLoop):
transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout,
- ssl_shutdown_timeout=ssl_shutdown_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout)
else:
transport = self._make_socket_transport(sock, protocol, waiter)
@@ -1204,8 +1193,7 @@ class BaseEventLoop(events.AbstractEventLoop):
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
- ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_handshake_timeout=None):
"""Upgrade transport to TLS.
Return a new transport that *protocol* should start using
@@ -1228,7 +1216,6 @@ class BaseEventLoop(events.AbstractEventLoop):
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
- ssl_shutdown_timeout=ssl_shutdown_timeout,
call_connection_made=False)
# Pause early so that "ssl_protocol.data_received()" doesn't
@@ -1427,7 +1414,6 @@ class BaseEventLoop(events.AbstractEventLoop):
reuse_address=None,
reuse_port=None,
ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None,
start_serving=True):
"""Create a TCP server.
@@ -1451,10 +1437,6 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
- if ssl_shutdown_timeout is not None and ssl is None:
- raise ValueError(
- 'ssl_shutdown_timeout is only meaningful with ssl')
-
if host is not None or port is not None:
if sock is not None:
raise ValueError(
@@ -1527,8 +1509,7 @@ class BaseEventLoop(events.AbstractEventLoop):
sock.setblocking(False)
server = Server(self, sockets, protocol_factory,
- ssl, backlog, ssl_handshake_timeout,
- ssl_shutdown_timeout)
+ ssl, backlog, ssl_handshake_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
@@ -1542,8 +1523,7 @@ class BaseEventLoop(events.AbstractEventLoop):
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
- ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_handshake_timeout=None):
if sock.type != socket.SOCK_STREAM:
raise ValueError(f'A Stream Socket was expected, got {sock!r}')
@@ -1551,14 +1531,9 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
- if ssl_shutdown_timeout is not None and not ssl:
- raise ValueError(
- 'ssl_shutdown_timeout is only meaningful with ssl')
-
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True,
- ssl_handshake_timeout=ssl_handshake_timeout,
- ssl_shutdown_timeout=ssl_shutdown_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
diff --git a/Lib/asyncio/constants.py b/Lib/asyncio/constants.py
index f171ead..33feed6 100644
--- a/Lib/asyncio/constants.py
+++ b/Lib/asyncio/constants.py
@@ -15,17 +15,10 @@ DEBUG_STACK_DEPTH = 10
# The default timeout matches that of Nginx.
SSL_HANDSHAKE_TIMEOUT = 60.0
-# Number of seconds to wait for SSL shutdown to complete
-# The default timeout mimics lingering_time
-SSL_SHUTDOWN_TIMEOUT = 30.0
-
# Used in sendfile fallback code. We use fallback for platforms
# that don't support sendfile, or for TLS connections.
SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256
-FLOW_CONTROL_HIGH_WATER_SSL_READ = 256 # KiB
-FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512 # KiB
-
# The enum should be here to break circular dependencies between
# base_events and sslproto
class _SendfileMode(enum.Enum):
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index d5254fa..b966ad2 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -304,7 +304,6 @@ class AbstractEventLoop:
flags=0, sock=None, local_addr=None,
server_hostname=None,
ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
raise NotImplementedError
@@ -314,7 +313,6 @@ class AbstractEventLoop:
flags=socket.AI_PASSIVE, sock=None, backlog=100,
ssl=None, reuse_address=None, reuse_port=None,
ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None,
start_serving=True):
"""A coroutine which creates a TCP server bound to host and port.
@@ -355,10 +353,6 @@ class AbstractEventLoop:
will wait for completion of the SSL handshake before aborting the
connection. Default is 60s.
- ssl_shutdown_timeout is the time in seconds that an SSL server
- will wait for completion of the SSL shutdown procedure
- before aborting the connection. Default is 30s.
-
start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever()
@@ -377,8 +371,7 @@ class AbstractEventLoop:
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
- ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_handshake_timeout=None):
"""Upgrade a transport to TLS.
Return a new transport that *protocol* should start using
@@ -390,15 +383,13 @@ class AbstractEventLoop:
self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None,
- ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_handshake_timeout=None):
raise NotImplementedError
async def create_unix_server(
self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None,
start_serving=True):
"""A coroutine which creates a UNIX Domain Socket server.
@@ -420,9 +411,6 @@ class AbstractEventLoop:
ssl_handshake_timeout is the time in seconds that an SSL server
will wait for the SSL handshake to complete (defaults to 60s).
- ssl_shutdown_timeout is the time in seconds that an SSL server
- will wait for the SSL shutdown to finish (defaults to 30s).
-
start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever()
@@ -433,8 +421,7 @@ class AbstractEventLoop:
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
- ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_handshake_timeout=None):
"""Handle an accepted connection.
This is used by servers that accept connections outside of
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index 10852af..45c11ee 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -642,13 +642,11 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
- ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_handshake_timeout=None):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout,
- ssl_shutdown_timeout=ssl_shutdown_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout)
_ProactorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
@@ -814,8 +812,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
- ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_handshake_timeout=None):
def loop(f=None):
try:
@@ -829,8 +826,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
self._make_ssl_transport(
conn, protocol, sslcontext, server_side=True,
extra={'peername': addr}, server=server,
- ssl_handshake_timeout=ssl_handshake_timeout,
- ssl_shutdown_timeout=ssl_shutdown_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout)
else:
self._make_socket_transport(
conn, protocol,
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index 63ab15f..59cb6b1 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -70,15 +70,11 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
- ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
- ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
- ):
+ ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_protocol = sslproto.SSLProtocol(
- self, protocol, sslcontext, waiter,
- server_side, server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout,
- ssl_shutdown_timeout=ssl_shutdown_timeout
- )
+ self, protocol, sslcontext, waiter,
+ server_side, server_hostname,
+ ssl_handshake_timeout=ssl_handshake_timeout)
_SelectorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
@@ -150,17 +146,15 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
- ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
- ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
+ ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
self._add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog,
- ssl_handshake_timeout, ssl_shutdown_timeout)
+ ssl_handshake_timeout)
def _accept_connection(
self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
- ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
- ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
+ ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
# This method is only called once for each event loop tick where the
# listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop.
@@ -191,22 +185,20 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
self.call_later(constants.ACCEPT_RETRY_DELAY,
self._start_serving,
protocol_factory, sock, sslcontext, server,
- backlog, ssl_handshake_timeout,
- ssl_shutdown_timeout)
+ backlog, ssl_handshake_timeout)
else:
raise # The event loop will catch, log and ignore it.
else:
extra = {'peername': addr}
accept = self._accept_connection2(
protocol_factory, conn, extra, sslcontext, server,
- ssl_handshake_timeout, ssl_shutdown_timeout)
+ ssl_handshake_timeout)
self.create_task(accept)
async def _accept_connection2(
self, protocol_factory, conn, extra,
sslcontext=None, server=None,
- ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
- ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
+ ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
protocol = None
transport = None
try:
@@ -216,8 +208,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
transport = self._make_ssl_transport(
conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra=extra, server=server,
- ssl_handshake_timeout=ssl_handshake_timeout,
- ssl_shutdown_timeout=ssl_shutdown_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout)
else:
transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra,
diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py
index 79734ab..cad25b2 100644
--- a/Lib/asyncio/sslproto.py
+++ b/Lib/asyncio/sslproto.py
@@ -1,5 +1,4 @@
import collections
-import enum
import warnings
try:
import ssl
@@ -7,38 +6,10 @@ except ImportError: # pragma: no cover
ssl = None
from . import constants
-from . import exceptions
from . import protocols
from . import transports
from .log import logger
-if ssl is not None:
- SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError)
-
-
-class SSLProtocolState(enum.Enum):
- UNWRAPPED = "UNWRAPPED"
- DO_HANDSHAKE = "DO_HANDSHAKE"
- WRAPPED = "WRAPPED"
- FLUSHING = "FLUSHING"
- SHUTDOWN = "SHUTDOWN"
-
-
-class AppProtocolState(enum.Enum):
- # This tracks the state of app protocol (https://git.io/fj59P):
- #
- # INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST
- #
- # * cm: connection_made()
- # * dr: data_received()
- # * er: eof_received()
- # * cl: connection_lost()
-
- STATE_INIT = "STATE_INIT"
- STATE_CON_MADE = "STATE_CON_MADE"
- STATE_EOF = "STATE_EOF"
- STATE_CON_LOST = "STATE_CON_LOST"
-
def _create_transport_context(server_side, server_hostname):
if server_side:
@@ -54,35 +25,269 @@ def _create_transport_context(server_side, server_hostname):
return sslcontext
-def add_flowcontrol_defaults(high, low, kb):
- if high is None:
- if low is None:
- hi = kb * 1024
- else:
- lo = low
- hi = 4 * lo
- else:
- hi = high
- if low is None:
- lo = hi // 4
- else:
- lo = low
+# States of an _SSLPipe.
+_UNWRAPPED = "UNWRAPPED"
+_DO_HANDSHAKE = "DO_HANDSHAKE"
+_WRAPPED = "WRAPPED"
+_SHUTDOWN = "SHUTDOWN"
+
+
+class _SSLPipe(object):
+ """An SSL "Pipe".
+
+ An SSL pipe allows you to communicate with an SSL/TLS protocol instance
+ through memory buffers. It can be used to implement a security layer for an
+ existing connection where you don't have access to the connection's file
+ descriptor, or for some reason you don't want to use it.
+
+ An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode,
+ data is passed through untransformed. In wrapped mode, application level
+ data is encrypted to SSL record level data and vice versa. The SSL record
+ level is the lowest level in the SSL protocol suite and is what travels
+ as-is over the wire.
+
+ An SslPipe initially is in "unwrapped" mode. To start SSL, call
+ do_handshake(). To shutdown SSL again, call unwrap().
+ """
+
+ max_size = 256 * 1024 # Buffer size passed to read()
+
+ def __init__(self, context, server_side, server_hostname=None):
+ """
+ The *context* argument specifies the ssl.SSLContext to use.
+
+ The *server_side* argument indicates whether this is a server side or
+ client side transport.
+
+ The optional *server_hostname* argument can be used to specify the
+ hostname you are connecting to. You may only specify this parameter if
+ the _ssl module supports Server Name Indication (SNI).
+ """
+ self._context = context
+ self._server_side = server_side
+ self._server_hostname = server_hostname
+ self._state = _UNWRAPPED
+ self._incoming = ssl.MemoryBIO()
+ self._outgoing = ssl.MemoryBIO()
+ self._sslobj = None
+ self._need_ssldata = False
+ self._handshake_cb = None
+ self._shutdown_cb = None
+
+ @property
+ def context(self):
+ """The SSL context passed to the constructor."""
+ return self._context
+
+ @property
+ def ssl_object(self):
+ """The internal ssl.SSLObject instance.
+
+ Return None if the pipe is not wrapped.
+ """
+ return self._sslobj
+
+ @property
+ def need_ssldata(self):
+ """Whether more record level data is needed to complete a handshake
+ that is currently in progress."""
+ return self._need_ssldata
+
+ @property
+ def wrapped(self):
+ """
+ Whether a security layer is currently in effect.
+
+ Return False during handshake.
+ """
+ return self._state == _WRAPPED
+
+ def do_handshake(self, callback=None):
+ """Start the SSL handshake.
+
+ Return a list of ssldata. A ssldata element is a list of buffers
+
+ The optional *callback* argument can be used to install a callback that
+ will be called when the handshake is complete. The callback will be
+ called with None if successful, else an exception instance.
+ """
+ if self._state != _UNWRAPPED:
+ raise RuntimeError('handshake in progress or completed')
+ self._sslobj = self._context.wrap_bio(
+ self._incoming, self._outgoing,
+ server_side=self._server_side,
+ server_hostname=self._server_hostname)
+ self._state = _DO_HANDSHAKE
+ self._handshake_cb = callback
+ ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
+ assert len(appdata) == 0
+ return ssldata
+
+ def shutdown(self, callback=None):
+ """Start the SSL shutdown sequence.
+
+ Return a list of ssldata. A ssldata element is a list of buffers
- if not hi >= lo >= 0:
- raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
- (hi, lo))
+ The optional *callback* argument can be used to install a callback that
+ will be called when the shutdown is complete. The callback will be
+ called without arguments.
+ """
+ if self._state == _UNWRAPPED:
+ raise RuntimeError('no security layer present')
+ if self._state == _SHUTDOWN:
+ raise RuntimeError('shutdown in progress')
+ assert self._state in (_WRAPPED, _DO_HANDSHAKE)
+ self._state = _SHUTDOWN
+ self._shutdown_cb = callback
+ ssldata, appdata = self.feed_ssldata(b'')
+ assert appdata == [] or appdata == [b'']
+ return ssldata
+
+ def feed_eof(self):
+ """Send a potentially "ragged" EOF.
+
+ This method will raise an SSL_ERROR_EOF exception if the EOF is
+ unexpected.
+ """
+ self._incoming.write_eof()
+ ssldata, appdata = self.feed_ssldata(b'')
+ assert appdata == [] or appdata == [b'']
+
+ def feed_ssldata(self, data, only_handshake=False):
+ """Feed SSL record level data into the pipe.
+
+ The data must be a bytes instance. It is OK to send an empty bytes
+ instance. This can be used to get ssldata for a handshake initiated by
+ this endpoint.
+
+ Return a (ssldata, appdata) tuple. The ssldata element is a list of
+ buffers containing SSL data that needs to be sent to the remote SSL.
+
+ The appdata element is a list of buffers containing plaintext data that
+ needs to be forwarded to the application. The appdata list may contain
+ an empty buffer indicating an SSL "close_notify" alert. This alert must
+ be acknowledged by calling shutdown().
+ """
+ if self._state == _UNWRAPPED:
+ # If unwrapped, pass plaintext data straight through.
+ if data:
+ appdata = [data]
+ else:
+ appdata = []
+ return ([], appdata)
+
+ self._need_ssldata = False
+ if data:
+ self._incoming.write(data)
+
+ ssldata = []
+ appdata = []
+ try:
+ if self._state == _DO_HANDSHAKE:
+ # Call do_handshake() until it doesn't raise anymore.
+ self._sslobj.do_handshake()
+ self._state = _WRAPPED
+ if self._handshake_cb:
+ self._handshake_cb(None)
+ if only_handshake:
+ return (ssldata, appdata)
+ # Handshake done: execute the wrapped block
+
+ if self._state == _WRAPPED:
+ # Main state: read data from SSL until close_notify
+ while True:
+ chunk = self._sslobj.read(self.max_size)
+ appdata.append(chunk)
+ if not chunk: # close_notify
+ break
+
+ elif self._state == _SHUTDOWN:
+ # Call shutdown() until it doesn't raise anymore.
+ self._sslobj.unwrap()
+ self._sslobj = None
+ self._state = _UNWRAPPED
+ if self._shutdown_cb:
+ self._shutdown_cb()
+
+ elif self._state == _UNWRAPPED:
+ # Drain possible plaintext data after close_notify.
+ appdata.append(self._incoming.read())
+ except (ssl.SSLError, ssl.CertificateError) as exc:
+ exc_errno = getattr(exc, 'errno', None)
+ if exc_errno not in (
+ ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
+ ssl.SSL_ERROR_SYSCALL):
+ if self._state == _DO_HANDSHAKE and self._handshake_cb:
+ self._handshake_cb(exc)
+ raise
+ self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
+
+ # Check for record level data that needs to be sent back.
+ # Happens for the initial handshake and renegotiations.
+ if self._outgoing.pending:
+ ssldata.append(self._outgoing.read())
+ return (ssldata, appdata)
+
+ def feed_appdata(self, data, offset=0):
+ """Feed plaintext data into the pipe.
+
+ Return an (ssldata, offset) tuple. The ssldata element is a list of
+ buffers containing record level data that needs to be sent to the
+ remote SSL instance. The offset is the number of plaintext bytes that
+ were processed, which may be less than the length of data.
+
+ NOTE: In case of short writes, this call MUST be retried with the SAME
+ buffer passed into the *data* argument (i.e. the id() must be the
+ same). This is an OpenSSL requirement. A further particularity is that
+ a short write will always have offset == 0, because the _ssl module
+ does not enable partial writes. And even though the offset is zero,
+ there will still be encrypted data in ssldata.
+ """
+ assert 0 <= offset <= len(data)
+ if self._state == _UNWRAPPED:
+ # pass through data in unwrapped mode
+ if offset < len(data):
+ ssldata = [data[offset:]]
+ else:
+ ssldata = []
+ return (ssldata, len(data))
- return hi, lo
+ ssldata = []
+ view = memoryview(data)
+ while True:
+ self._need_ssldata = False
+ try:
+ if offset < len(view):
+ offset += self._sslobj.write(view[offset:])
+ except ssl.SSLError as exc:
+ # It is not allowed to call write() after unwrap() until the
+ # close_notify is acknowledged. We return the condition to the
+ # caller as a short write.
+ exc_errno = getattr(exc, 'errno', None)
+ if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
+ exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
+ if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
+ ssl.SSL_ERROR_WANT_WRITE,
+ ssl.SSL_ERROR_SYSCALL):
+ raise
+ self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
+
+ # See if there's any record level data back for us.
+ if self._outgoing.pending:
+ ssldata.append(self._outgoing.read())
+ if offset == len(view) or self._need_ssldata:
+ break
+ return (ssldata, offset)
class _SSLProtocolTransport(transports._FlowControlMixin,
transports.Transport):
- _start_tls_compatible = True
_sendfile_compatible = constants._SendfileMode.FALLBACK
def __init__(self, loop, ssl_protocol):
self._loop = loop
+ # SSLProtocol instance
self._ssl_protocol = ssl_protocol
self._closed = False
@@ -110,15 +315,16 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
self._closed = True
self._ssl_protocol._start_shutdown()
- def __del__(self, _warnings=warnings):
+ def __del__(self, _warn=warnings.warn):
if not self._closed:
- self._closed = True
- _warnings.warn(
- "unclosed transport <asyncio._SSLProtocolTransport "
- "object>", ResourceWarning)
+ _warn(f"unclosed transport {self!r}", ResourceWarning, source=self)
+ self.close()
def is_reading(self):
- return not self._ssl_protocol._app_reading_paused
+ tr = self._ssl_protocol._transport
+ if tr is None:
+ raise RuntimeError('SSL transport has not been initialized yet')
+ return tr.is_reading()
def pause_reading(self):
"""Pause the receiving end.
@@ -126,7 +332,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
No data will be passed to the protocol's data_received()
method until resume_reading() is called.
"""
- self._ssl_protocol._pause_reading()
+ self._ssl_protocol._transport.pause_reading()
def resume_reading(self):
"""Resume the receiving end.
@@ -134,7 +340,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
Data received will once again be passed to the protocol's
data_received() method.
"""
- self._ssl_protocol._resume_reading()
+ self._ssl_protocol._transport.resume_reading()
def set_write_buffer_limits(self, high=None, low=None):
"""Set the high- and low-water limits for write flow control.
@@ -155,51 +361,16 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
reduces opportunities for doing I/O and computation
concurrently.
"""
- self._ssl_protocol._set_write_buffer_limits(high, low)
- self._ssl_protocol._control_app_writing()
-
- def get_write_buffer_limits(self):
- return (self._ssl_protocol._outgoing_low_water,
- self._ssl_protocol._outgoing_high_water)
+ self._ssl_protocol._transport.set_write_buffer_limits(high, low)
def get_write_buffer_size(self):
- """Return the current size of the write buffers."""
- return self._ssl_protocol._get_write_buffer_size()
-
- def set_read_buffer_limits(self, high=None, low=None):
- """Set the high- and low-water limits for read flow control.
-
- These two values control when to call the upstream transport's
- pause_reading() and resume_reading() methods. If specified,
- the low-water limit must be less than or equal to the
- high-water limit. Neither value can be negative.
-
- The defaults are implementation-specific. If only the
- high-water limit is given, the low-water limit defaults to an
- implementation-specific value less than or equal to the
- high-water limit. Setting high to zero forces low to zero as
- well, and causes pause_reading() to be called whenever the
- buffer becomes non-empty. Setting low to zero causes
- resume_reading() to be called only once the buffer is empty.
- Use of zero for either limit is generally sub-optimal as it
- reduces opportunities for doing I/O and computation
- concurrently.
- """
- self._ssl_protocol._set_read_buffer_limits(high, low)
- self._ssl_protocol._control_ssl_reading()
-
- def get_read_buffer_limits(self):
- return (self._ssl_protocol._incoming_low_water,
- self._ssl_protocol._incoming_high_water)
-
- def get_read_buffer_size(self):
- """Return the current size of the read buffer."""
- return self._ssl_protocol._get_read_buffer_size()
+ """Return the current size of the write buffer."""
+ return self._ssl_protocol._transport.get_write_buffer_size()
@property
def _protocol_paused(self):
# Required for sendfile fallback pause_writing/resume_writing logic
- return self._ssl_protocol._app_writing_paused
+ return self._ssl_protocol._transport._protocol_paused
def write(self, data):
"""Write some data bytes to the transport.
@@ -212,22 +383,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
f"got {type(data).__name__}")
if not data:
return
- self._ssl_protocol._write_appdata((data,))
-
- def writelines(self, list_of_data):
- """Write a list (or any iterable) of data bytes to the transport.
-
- The default implementation concatenates the arguments and
- calls write() on the result.
- """
- self._ssl_protocol._write_appdata(list_of_data)
-
- def write_eof(self):
- """Close the write end after flushing buffered data.
-
- This raises :exc:`NotImplementedError` right now.
- """
- raise NotImplementedError
+ self._ssl_protocol._write_appdata(data)
def can_write_eof(self):
"""Return True if this transport supports write_eof(), False if not."""
@@ -240,36 +396,23 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
The protocol's connection_lost() method will (eventually) be
called with None as its argument.
"""
- self._closed = True
self._ssl_protocol._abort()
-
- def _force_close(self, exc):
self._closed = True
- self._ssl_protocol._abort(exc)
- def _test__append_write_backlog(self, data):
- # for test only
- self._ssl_protocol._write_backlog.append(data)
- self._ssl_protocol._write_buffer_size += len(data)
+class SSLProtocol(protocols.Protocol):
+ """SSL protocol.
-class SSLProtocol(protocols.BufferedProtocol):
- max_size = 256 * 1024 # Buffer size passed to read()
-
- _handshake_start_time = None
- _handshake_timeout_handle = None
- _shutdown_timeout_handle = None
+ Implementation of SSL on top of a socket using incoming and outgoing
+ buffers which are ssl.MemoryBIO objects.
+ """
def __init__(self, loop, app_protocol, sslcontext, waiter,
server_side=False, server_hostname=None,
call_connection_made=True,
- ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_handshake_timeout=None):
if ssl is None:
- raise RuntimeError("stdlib ssl module not available")
-
- self._ssl_buffer = bytearray(self.max_size)
- self._ssl_buffer_view = memoryview(self._ssl_buffer)
+ raise RuntimeError('stdlib ssl module not available')
if ssl_handshake_timeout is None:
ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
@@ -277,12 +420,6 @@ class SSLProtocol(protocols.BufferedProtocol):
raise ValueError(
f"ssl_handshake_timeout should be a positive number, "
f"got {ssl_handshake_timeout}")
- if ssl_shutdown_timeout is None:
- ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT
- elif ssl_shutdown_timeout <= 0:
- raise ValueError(
- f"ssl_shutdown_timeout should be a positive number, "
- f"got {ssl_shutdown_timeout}")
if not sslcontext:
sslcontext = _create_transport_context(
@@ -305,54 +442,21 @@ class SSLProtocol(protocols.BufferedProtocol):
self._waiter = waiter
self._loop = loop
self._set_app_protocol(app_protocol)
- self._app_transport = None
- self._app_transport_created = False
+ self._app_transport = _SSLProtocolTransport(self._loop, self)
+ # _SSLPipe instance (None until the connection is made)
+ self._sslpipe = None
+ self._session_established = False
+ self._in_handshake = False
+ self._in_shutdown = False
# transport, ex: SelectorSocketTransport
self._transport = None
+ self._call_connection_made = call_connection_made
self._ssl_handshake_timeout = ssl_handshake_timeout
- self._ssl_shutdown_timeout = ssl_shutdown_timeout
- # SSL and state machine
- self._incoming = ssl.MemoryBIO()
- self._outgoing = ssl.MemoryBIO()
- self._state = SSLProtocolState.UNWRAPPED
- self._conn_lost = 0 # Set when connection_lost called
- if call_connection_made:
- self._app_state = AppProtocolState.STATE_INIT
- else:
- self._app_state = AppProtocolState.STATE_CON_MADE
- self._sslobj = self._sslcontext.wrap_bio(
- self._incoming, self._outgoing,
- server_side=self._server_side,
- server_hostname=self._server_hostname)
-
- # Flow Control
-
- self._ssl_writing_paused = False
-
- self._app_reading_paused = False
-
- self._ssl_reading_paused = False
- self._incoming_high_water = 0
- self._incoming_low_water = 0
- self._set_read_buffer_limits()
- self._eof_received = False
-
- self._app_writing_paused = False
- self._outgoing_high_water = 0
- self._outgoing_low_water = 0
- self._set_write_buffer_limits()
- self._get_app_transport()
def _set_app_protocol(self, app_protocol):
self._app_protocol = app_protocol
- # Make fast hasattr check first
- if (hasattr(app_protocol, 'get_buffer') and
- isinstance(app_protocol, protocols.BufferedProtocol)):
- self._app_protocol_get_buffer = app_protocol.get_buffer
- self._app_protocol_buffer_updated = app_protocol.buffer_updated
- self._app_protocol_is_buffer = True
- else:
- self._app_protocol_is_buffer = False
+ self._app_protocol_is_buffer = \
+ isinstance(app_protocol, protocols.BufferedProtocol)
def _wakeup_waiter(self, exc=None):
if self._waiter is None:
@@ -364,20 +468,15 @@ class SSLProtocol(protocols.BufferedProtocol):
self._waiter.set_result(None)
self._waiter = None
- def _get_app_transport(self):
- if self._app_transport is None:
- if self._app_transport_created:
- raise RuntimeError('Creating _SSLProtocolTransport twice')
- self._app_transport = _SSLProtocolTransport(self._loop, self)
- self._app_transport_created = True
- return self._app_transport
-
def connection_made(self, transport):
"""Called when the low-level connection is made.
Start the SSL handshake.
"""
self._transport = transport
+ self._sslpipe = _SSLPipe(self._sslcontext,
+ self._server_side,
+ self._server_hostname)
self._start_handshake()
def connection_lost(self, exc):
@@ -387,58 +486,72 @@ class SSLProtocol(protocols.BufferedProtocol):
meaning a regular EOF is received or the connection was
aborted or closed).
"""
- self._write_backlog.clear()
- self._outgoing.read()
- self._conn_lost += 1
-
- # Just mark the app transport as closed so that its __dealloc__
- # doesn't complain.
- if self._app_transport is not None:
- self._app_transport._closed = True
-
- if self._state != SSLProtocolState.DO_HANDSHAKE:
- if (
- self._app_state == AppProtocolState.STATE_CON_MADE or
- self._app_state == AppProtocolState.STATE_EOF
- ):
- self._app_state = AppProtocolState.STATE_CON_LOST
- self._loop.call_soon(self._app_protocol.connection_lost, exc)
- self._set_state(SSLProtocolState.UNWRAPPED)
+ if self._session_established:
+ self._session_established = False
+ self._loop.call_soon(self._app_protocol.connection_lost, exc)
+ else:
+ # Most likely an exception occurred while in SSL handshake.
+ # Just mark the app transport as closed so that its __del__
+ # doesn't complain.
+ if self._app_transport is not None:
+ self._app_transport._closed = True
self._transport = None
self._app_transport = None
- self._app_protocol = None
+ if getattr(self, '_handshake_timeout_handle', None):
+ self._handshake_timeout_handle.cancel()
self._wakeup_waiter(exc)
+ self._app_protocol = None
+ self._sslpipe = None
- if self._shutdown_timeout_handle:
- self._shutdown_timeout_handle.cancel()
- self._shutdown_timeout_handle = None
- if self._handshake_timeout_handle:
- self._handshake_timeout_handle.cancel()
- self._handshake_timeout_handle = None
+ def pause_writing(self):
+ """Called when the low-level transport's buffer goes over
+ the high-water mark.
+ """
+ self._app_protocol.pause_writing()
- def get_buffer(self, n):
- want = n
- if want <= 0 or want > self.max_size:
- want = self.max_size
- if len(self._ssl_buffer) < want:
- self._ssl_buffer = bytearray(want)
- self._ssl_buffer_view = memoryview(self._ssl_buffer)
- return self._ssl_buffer_view
+ def resume_writing(self):
+ """Called when the low-level transport's buffer drains below
+ the low-water mark.
+ """
+ self._app_protocol.resume_writing()
- def buffer_updated(self, nbytes):
- self._incoming.write(self._ssl_buffer_view[:nbytes])
+ def data_received(self, data):
+ """Called when some SSL data is received.
- if self._state == SSLProtocolState.DO_HANDSHAKE:
- self._do_handshake()
+ The argument is a bytes object.
+ """
+ if self._sslpipe is None:
+ # transport closing, sslpipe is destroyed
+ return
- elif self._state == SSLProtocolState.WRAPPED:
- self._do_read()
+ try:
+ ssldata, appdata = self._sslpipe.feed_ssldata(data)
+ except (SystemExit, KeyboardInterrupt):
+ raise
+ except BaseException as e:
+ self._fatal_error(e, 'SSL error in data received')
+ return
- elif self._state == SSLProtocolState.FLUSHING:
- self._do_flush()
+ for chunk in ssldata:
+ self._transport.write(chunk)
- elif self._state == SSLProtocolState.SHUTDOWN:
- self._do_shutdown()
+ for chunk in appdata:
+ if chunk:
+ try:
+ if self._app_protocol_is_buffer:
+ protocols._feed_data_to_buffered_proto(
+ self._app_protocol, chunk)
+ else:
+ self._app_protocol.data_received(chunk)
+ except (SystemExit, KeyboardInterrupt):
+ raise
+ except BaseException as ex:
+ self._fatal_error(
+ ex, 'application protocol failed to receive SSL data')
+ return
+ else:
+ self._start_shutdown()
+ break
def eof_received(self):
"""Called when the other end of the low-level stream
@@ -448,32 +561,19 @@ class SSLProtocol(protocols.BufferedProtocol):
will close itself. If it returns a true value, closing the
transport is up to the protocol.
"""
- self._eof_received = True
try:
if self._loop.get_debug():
logger.debug("%r received EOF", self)
- if self._state == SSLProtocolState.DO_HANDSHAKE:
- self._on_handshake_complete(ConnectionResetError)
-
- elif self._state == SSLProtocolState.WRAPPED:
- self._set_state(SSLProtocolState.FLUSHING)
- if self._app_reading_paused:
- return True
- else:
- self._do_flush()
-
- elif self._state == SSLProtocolState.FLUSHING:
- self._do_write()
- self._set_state(SSLProtocolState.SHUTDOWN)
- self._do_shutdown()
+ self._wakeup_waiter(ConnectionResetError)
- elif self._state == SSLProtocolState.SHUTDOWN:
- self._do_shutdown()
-
- except Exception:
+ if not self._in_handshake:
+ keep_open = self._app_protocol.eof_received()
+ if keep_open:
+ logger.warning('returning true from eof_received() '
+ 'has no effect when using ssl')
+ finally:
self._transport.close()
- raise
def _get_extra_info(self, name, default=None):
if name in self._extra:
@@ -483,45 +583,19 @@ class SSLProtocol(protocols.BufferedProtocol):
else:
return default
- def _set_state(self, new_state):
- allowed = False
-
- if new_state == SSLProtocolState.UNWRAPPED:
- allowed = True
-
- elif (
- self._state == SSLProtocolState.UNWRAPPED and
- new_state == SSLProtocolState.DO_HANDSHAKE
- ):
- allowed = True
-
- elif (
- self._state == SSLProtocolState.DO_HANDSHAKE and
- new_state == SSLProtocolState.WRAPPED
- ):
- allowed = True
-
- elif (
- self._state == SSLProtocolState.WRAPPED and
- new_state == SSLProtocolState.FLUSHING
- ):
- allowed = True
-
- elif (
- self._state == SSLProtocolState.FLUSHING and
- new_state == SSLProtocolState.SHUTDOWN
- ):
- allowed = True
-
- if allowed:
- self._state = new_state
-
+ def _start_shutdown(self):
+ if self._in_shutdown:
+ return
+ if self._in_handshake:
+ self._abort()
else:
- raise RuntimeError(
- 'cannot switch state from {} to {}'.format(
- self._state, new_state))
+ self._in_shutdown = True
+ self._write_appdata(b'')
- # Handshake flow
+ def _write_appdata(self, data):
+ self._write_backlog.append((data, 0))
+ self._write_buffer_size += len(data)
+ self._process_write_backlog()
def _start_handshake(self):
if self._loop.get_debug():
@@ -529,18 +603,17 @@ class SSLProtocol(protocols.BufferedProtocol):
self._handshake_start_time = self._loop.time()
else:
self._handshake_start_time = None
-
- self._set_state(SSLProtocolState.DO_HANDSHAKE)
-
- # start handshake timeout count down
+ self._in_handshake = True
+ # (b'', 1) is a special value in _process_write_backlog() to do
+ # the SSL handshake
+ self._write_backlog.append((b'', 1))
self._handshake_timeout_handle = \
self._loop.call_later(self._ssl_handshake_timeout,
- lambda: self._check_handshake_timeout())
-
- self._do_handshake()
+ self._check_handshake_timeout)
+ self._process_write_backlog()
def _check_handshake_timeout(self):
- if self._state == SSLProtocolState.DO_HANDSHAKE:
+ if self._in_handshake is True:
msg = (
f"SSL handshake is taking longer than "
f"{self._ssl_handshake_timeout} seconds: "
@@ -548,37 +621,24 @@ class SSLProtocol(protocols.BufferedProtocol):
)
self._fatal_error(ConnectionAbortedError(msg))
- def _do_handshake(self):
- try:
- self._sslobj.do_handshake()
- except SSLAgainErrors:
- self._process_outgoing()
- except ssl.SSLError as exc:
- self._on_handshake_complete(exc)
- else:
- self._on_handshake_complete(None)
-
def _on_handshake_complete(self, handshake_exc):
- if self._handshake_timeout_handle is not None:
- self._handshake_timeout_handle.cancel()
- self._handshake_timeout_handle = None
+ self._in_handshake = False
+ self._handshake_timeout_handle.cancel()
- sslobj = self._sslobj
+ sslobj = self._sslpipe.ssl_object
try:
- if handshake_exc is None:
- self._set_state(SSLProtocolState.WRAPPED)
- else:
+ if handshake_exc is not None:
raise handshake_exc
peercert = sslobj.getpeercert()
- except Exception as exc:
- self._set_state(SSLProtocolState.UNWRAPPED)
+ except (SystemExit, KeyboardInterrupt):
+ raise
+ except BaseException as exc:
if isinstance(exc, ssl.CertificateError):
msg = 'SSL handshake failed on verifying the certificate'
else:
msg = 'SSL handshake failed'
self._fatal_error(exc, msg)
- self._wakeup_waiter(exc)
return
if self._loop.get_debug():
@@ -589,330 +649,85 @@ class SSLProtocol(protocols.BufferedProtocol):
self._extra.update(peercert=peercert,
cipher=sslobj.cipher(),
compression=sslobj.compression(),
- ssl_object=sslobj)
- if self._app_state == AppProtocolState.STATE_INIT:
- self._app_state = AppProtocolState.STATE_CON_MADE
- self._app_protocol.connection_made(self._get_app_transport())
+ ssl_object=sslobj,
+ )
+ if self._call_connection_made:
+ self._app_protocol.connection_made(self._app_transport)
self._wakeup_waiter()
- self._do_read()
-
- # Shutdown flow
-
- def _start_shutdown(self):
- if (
- self._state in (
- SSLProtocolState.FLUSHING,
- SSLProtocolState.SHUTDOWN,
- SSLProtocolState.UNWRAPPED
- )
- ):
- return
- if self._app_transport is not None:
- self._app_transport._closed = True
- if self._state == SSLProtocolState.DO_HANDSHAKE:
- self._abort()
- else:
- self._set_state(SSLProtocolState.FLUSHING)
- self._shutdown_timeout_handle = self._loop.call_later(
- self._ssl_shutdown_timeout,
- lambda: self._check_shutdown_timeout()
- )
- self._do_flush()
-
- def _check_shutdown_timeout(self):
- if (
- self._state in (
- SSLProtocolState.FLUSHING,
- SSLProtocolState.SHUTDOWN
- )
- ):
- self._transport._force_close(
- exceptions.TimeoutError('SSL shutdown timed out'))
-
- def _do_flush(self):
- self._do_read()
- self._set_state(SSLProtocolState.SHUTDOWN)
- self._do_shutdown()
-
- def _do_shutdown(self):
- try:
- if not self._eof_received:
- self._sslobj.unwrap()
- except SSLAgainErrors:
- self._process_outgoing()
- except ssl.SSLError as exc:
- self._on_shutdown_complete(exc)
- else:
- self._process_outgoing()
- self._call_eof_received()
- self._on_shutdown_complete(None)
-
- def _on_shutdown_complete(self, shutdown_exc):
- if self._shutdown_timeout_handle is not None:
- self._shutdown_timeout_handle.cancel()
- self._shutdown_timeout_handle = None
-
- if shutdown_exc:
- self._fatal_error(shutdown_exc)
- else:
- self._loop.call_soon(self._transport.close)
-
- def _abort(self):
- self._set_state(SSLProtocolState.UNWRAPPED)
- if self._transport is not None:
- self._transport.abort()
-
- # Outgoing flow
-
- def _write_appdata(self, list_of_data):
- if (
- self._state in (
- SSLProtocolState.FLUSHING,
- SSLProtocolState.SHUTDOWN,
- SSLProtocolState.UNWRAPPED
- )
- ):
- if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
- logger.warning('SSL connection is closed')
- self._conn_lost += 1
+ self._session_established = True
+ # In case transport.write() was already called. Don't call
+ # immediately _process_write_backlog(), but schedule it:
+ # _on_handshake_complete() can be called indirectly from
+ # _process_write_backlog(), and _process_write_backlog() is not
+ # reentrant.
+ self._loop.call_soon(self._process_write_backlog)
+
+ def _process_write_backlog(self):
+ # Try to make progress on the write backlog.
+ if self._transport is None or self._sslpipe is None:
return
- for data in list_of_data:
- self._write_backlog.append(data)
- self._write_buffer_size += len(data)
-
try:
- if self._state == SSLProtocolState.WRAPPED:
- self._do_write()
-
- except Exception as ex:
- self._fatal_error(ex, 'Fatal error on SSL protocol')
-
- def _do_write(self):
- try:
- while self._write_backlog:
- data = self._write_backlog[0]
- count = self._sslobj.write(data)
- data_len = len(data)
- if count < data_len:
- self._write_backlog[0] = data[count:]
- self._write_buffer_size -= count
+ for i in range(len(self._write_backlog)):
+ data, offset = self._write_backlog[0]
+ if data:
+ ssldata, offset = self._sslpipe.feed_appdata(data, offset)
+ elif offset:
+ ssldata = self._sslpipe.do_handshake(
+ self._on_handshake_complete)
+ offset = 1
else:
- del self._write_backlog[0]
- self._write_buffer_size -= data_len
- except SSLAgainErrors:
- pass
- self._process_outgoing()
-
- def _process_outgoing(self):
- if not self._ssl_writing_paused:
- data = self._outgoing.read()
- if len(data):
- self._transport.write(data)
- self._control_app_writing()
-
- # Incoming flow
-
- def _do_read(self):
- if (
- self._state not in (
- SSLProtocolState.WRAPPED,
- SSLProtocolState.FLUSHING,
- )
- ):
- return
- try:
- if not self._app_reading_paused:
- if self._app_protocol_is_buffer:
- self._do_read__buffered()
- else:
- self._do_read__copied()
- if self._write_backlog:
- self._do_write()
- else:
- self._process_outgoing()
- self._control_ssl_reading()
- except Exception as ex:
- self._fatal_error(ex, 'Fatal error on SSL protocol')
-
- def _do_read__buffered(self):
- offset = 0
- count = 1
-
- buf = self._app_protocol_get_buffer(self._get_read_buffer_size())
- wants = len(buf)
-
- try:
- count = self._sslobj.read(wants, buf)
-
- if count > 0:
- offset = count
- while offset < wants:
- count = self._sslobj.read(wants - offset, buf[offset:])
- if count > 0:
- offset += count
- else:
- break
- else:
- self._loop.call_soon(lambda: self._do_read())
- except SSLAgainErrors:
- pass
- if offset > 0:
- self._app_protocol_buffer_updated(offset)
- if not count:
- # close_notify
- self._call_eof_received()
- self._start_shutdown()
-
- def _do_read__copied(self):
- chunk = b'1'
- zero = True
- one = False
-
- try:
- while True:
- chunk = self._sslobj.read(self.max_size)
- if not chunk:
+ ssldata = self._sslpipe.shutdown(self._finalize)
+ offset = 1
+
+ for chunk in ssldata:
+ self._transport.write(chunk)
+
+ if offset < len(data):
+ self._write_backlog[0] = (data, offset)
+ # A short write means that a write is blocked on a read
+ # We need to enable reading if it is paused!
+ assert self._sslpipe.need_ssldata
+ if self._transport._paused:
+ self._transport.resume_reading()
break
- if zero:
- zero = False
- one = True
- first = chunk
- elif one:
- one = False
- data = [first, chunk]
- else:
- data.append(chunk)
- except SSLAgainErrors:
- pass
- if one:
- self._app_protocol.data_received(first)
- elif not zero:
- self._app_protocol.data_received(b''.join(data))
- if not chunk:
- # close_notify
- self._call_eof_received()
- self._start_shutdown()
-
- def _call_eof_received(self):
- try:
- if self._app_state == AppProtocolState.STATE_CON_MADE:
- self._app_state = AppProtocolState.STATE_EOF
- keep_open = self._app_protocol.eof_received()
- if keep_open:
- logger.warning('returning true from eof_received() '
- 'has no effect when using ssl')
- except (KeyboardInterrupt, SystemExit):
- raise
- except BaseException as ex:
- self._fatal_error(ex, 'Error calling eof_received()')
-
- # Flow control for writes from APP socket
- def _control_app_writing(self):
- size = self._get_write_buffer_size()
- if size >= self._outgoing_high_water and not self._app_writing_paused:
- self._app_writing_paused = True
- try:
- self._app_protocol.pause_writing()
- except (KeyboardInterrupt, SystemExit):
- raise
- except BaseException as exc:
- self._loop.call_exception_handler({
- 'message': 'protocol.pause_writing() failed',
- 'exception': exc,
- 'transport': self._app_transport,
- 'protocol': self,
- })
- elif size <= self._outgoing_low_water and self._app_writing_paused:
- self._app_writing_paused = False
- try:
- self._app_protocol.resume_writing()
- except (KeyboardInterrupt, SystemExit):
- raise
- except BaseException as exc:
- self._loop.call_exception_handler({
- 'message': 'protocol.resume_writing() failed',
- 'exception': exc,
- 'transport': self._app_transport,
- 'protocol': self,
- })
-
- def _get_write_buffer_size(self):
- return self._outgoing.pending + self._write_buffer_size
-
- def _set_write_buffer_limits(self, high=None, low=None):
- high, low = add_flowcontrol_defaults(
- high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE)
- self._outgoing_high_water = high
- self._outgoing_low_water = low
-
- # Flow control for reads to APP socket
-
- def _pause_reading(self):
- self._app_reading_paused = True
-
- def _resume_reading(self):
- if self._app_reading_paused:
- self._app_reading_paused = False
-
- def resume():
- if self._state == SSLProtocolState.WRAPPED:
- self._do_read()
- elif self._state == SSLProtocolState.FLUSHING:
- self._do_flush()
- elif self._state == SSLProtocolState.SHUTDOWN:
- self._do_shutdown()
- self._loop.call_soon(resume)
-
- # Flow control for reads from SSL socket
-
- def _control_ssl_reading(self):
- size = self._get_read_buffer_size()
- if size >= self._incoming_high_water and not self._ssl_reading_paused:
- self._ssl_reading_paused = True
- self._transport.pause_reading()
- elif size <= self._incoming_low_water and self._ssl_reading_paused:
- self._ssl_reading_paused = False
- self._transport.resume_reading()
-
- def _set_read_buffer_limits(self, high=None, low=None):
- high, low = add_flowcontrol_defaults(
- high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ)
- self._incoming_high_water = high
- self._incoming_low_water = low
-
- def _get_read_buffer_size(self):
- return self._incoming.pending
-
- # Flow control for writes to SSL socket
-
- def pause_writing(self):
- """Called when the low-level transport's buffer goes over
- the high-water mark.
- """
- assert not self._ssl_writing_paused
- self._ssl_writing_paused = True
-
- def resume_writing(self):
- """Called when the low-level transport's buffer drains below
- the low-water mark.
- """
- assert self._ssl_writing_paused
- self._ssl_writing_paused = False
- self._process_outgoing()
+ # An entire chunk from the backlog was processed. We can
+ # delete it and reduce the outstanding buffer size.
+ del self._write_backlog[0]
+ self._write_buffer_size -= len(data)
+ except (SystemExit, KeyboardInterrupt):
+ raise
+ except BaseException as exc:
+ if self._in_handshake:
+ # Exceptions will be re-raised in _on_handshake_complete.
+ self._on_handshake_complete(exc)
+ else:
+ self._fatal_error(exc, 'Fatal error on SSL transport')
def _fatal_error(self, exc, message='Fatal error on transport'):
- if self._transport:
- self._transport._force_close(exc)
-
if isinstance(exc, OSError):
if self._loop.get_debug():
logger.debug("%r: %s", self, message, exc_info=True)
- elif not isinstance(exc, exceptions.CancelledError):
+ else:
self._loop.call_exception_handler({
'message': message,
'exception': exc,
'transport': self._transport,
'protocol': self,
})
+ if self._transport:
+ self._transport._force_close(exc)
+
+ def _finalize(self):
+ self._sslpipe = None
+
+ if self._transport is not None:
+ self._transport.close()
+
+ def _abort(self):
+ try:
+ if self._transport is not None:
+ self._transport.abort()
+ finally:
+ self._finalize()
diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py
index 181e188..a55b3a3 100644
--- a/Lib/asyncio/unix_events.py
+++ b/Lib/asyncio/unix_events.py
@@ -229,8 +229,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None,
- ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_handshake_timeout=None):
assert server_hostname is None or isinstance(server_hostname, str)
if ssl:
if server_hostname is None:
@@ -242,9 +241,6 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
if ssl_handshake_timeout is not None:
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
- if ssl_shutdown_timeout is not None:
- raise ValueError(
- 'ssl_shutdown_timeout is only meaningful with ssl')
if path is not None:
if sock is not None:
@@ -271,15 +267,13 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout,
- ssl_shutdown_timeout=ssl_shutdown_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout)
return transport, protocol
async def create_unix_server(
self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None,
start_serving=True):
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')
@@ -288,10 +282,6 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
- if ssl_shutdown_timeout is not None and not ssl:
- raise ValueError(
- 'ssl_shutdown_timeout is only meaningful with ssl')
-
if path is not None:
if sock is not None:
raise ValueError(
@@ -338,8 +328,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
sock.setblocking(False)
server = base_events.Server(self, [sock], protocol_factory,
- ssl, backlog, ssl_handshake_timeout,
- ssl_shutdown_timeout)
+ ssl, backlog, ssl_handshake_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
index be5ea1e..5691d42 100644
--- a/Lib/test/test_asyncio/test_base_events.py
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -1437,51 +1437,44 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
ANY = mock.ANY
handshake_timeout = object()
- shutdown_timeout = object()
# First try the default server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
- ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='python.org',
- ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_handshake_timeout=handshake_timeout)
# Next try an explicit server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='perl.com',
- ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='perl.com',
- ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_handshake_timeout=handshake_timeout)
# Finally try an explicit empty server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='',
- ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='',
- ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_handshake_timeout=handshake_timeout)
def test_create_connection_no_ssl_server_hostname_errors(self):
# When not using ssl, server_hostname must be None.
@@ -1888,7 +1881,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
constants.ACCEPT_RETRY_DELAY,
# self.loop._start_serving
mock.ANY,
- MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY)
+ MyProto, sock, None, None, mock.ANY, mock.ANY)
def test_call_coroutine(self):
with self.assertWarns(DeprecationWarning):
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
index 349e4f2..1613c75 100644
--- a/Lib/test/test_asyncio/test_selector_events.py
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -70,6 +70,44 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
close_transport(transport)
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_make_ssl_transport(self):
+ m = mock.Mock()
+ self.loop._add_reader = mock.Mock()
+ self.loop._add_reader._is_coroutine = False
+ self.loop._add_writer = mock.Mock()
+ self.loop._remove_reader = mock.Mock()
+ self.loop._remove_writer = mock.Mock()
+ waiter = self.loop.create_future()
+ with test_utils.disable_logger():
+ transport = self.loop._make_ssl_transport(
+ m, asyncio.Protocol(), m, waiter)
+
+ with self.assertRaisesRegex(RuntimeError,
+ r'SSL transport.*not.*initialized'):
+ transport.is_reading()
+
+ # execute the handshake while the logger is disabled
+ # to ignore SSL handshake failure
+ test_utils.run_briefly(self.loop)
+
+ self.assertTrue(transport.is_reading())
+ transport.pause_reading()
+ transport.pause_reading()
+ self.assertFalse(transport.is_reading())
+ transport.resume_reading()
+ transport.resume_reading()
+ self.assertTrue(transport.is_reading())
+
+ # Sanity check
+ class_name = transport.__class__.__name__
+ self.assertIn("ssl", class_name.lower())
+ self.assertIn("transport", class_name.lower())
+
+ transport.close()
+ # execute pending callbacks to close the socket transport
+ test_utils.run_briefly(self.loop)
+
@mock.patch('asyncio.selector_events.ssl', None)
@mock.patch('asyncio.sslproto.ssl', None)
def test_make_ssl_transport_without_ssl_error(self):
diff --git a/Lib/test/test_asyncio/test_ssl.py b/Lib/test/test_asyncio/test_ssl.py
deleted file mode 100644
index 9cdd281..0000000
--- a/Lib/test/test_asyncio/test_ssl.py
+++ /dev/null
@@ -1,1723 +0,0 @@
-import asyncio
-import asyncio.sslproto
-import contextlib
-import gc
-import logging
-import select
-import socket
-import tempfile
-import threading
-import time
-import weakref
-import unittest
-
-try:
- import ssl
-except ImportError:
- ssl = None
-
-from test import support
-from test.test_asyncio import utils as test_utils
-
-
-def tearDownModule():
- asyncio.set_event_loop_policy(None)
-
-
-class MyBaseProto(asyncio.Protocol):
- connected = None
- done = None
-
- def __init__(self, loop=None):
- self.transport = None
- self.state = 'INITIAL'
- self.nbytes = 0
- if loop is not None:
- self.connected = asyncio.Future(loop=loop)
- self.done = asyncio.Future(loop=loop)
-
- def connection_made(self, transport):
- self.transport = transport
- assert self.state == 'INITIAL', self.state
- self.state = 'CONNECTED'
- if self.connected:
- self.connected.set_result(None)
-
- def data_received(self, data):
- assert self.state == 'CONNECTED', self.state
- self.nbytes += len(data)
-
- def eof_received(self):
- assert self.state == 'CONNECTED', self.state
- self.state = 'EOF'
-
- def connection_lost(self, exc):
- assert self.state in ('CONNECTED', 'EOF'), self.state
- self.state = 'CLOSED'
- if self.done:
- self.done.set_result(None)
-
-
-@unittest.skipIf(ssl is None, 'No ssl module')
-class TestSSL(test_utils.TestCase):
-
- PAYLOAD_SIZE = 1024 * 100
- TIMEOUT = 60
-
- def setUp(self):
- super().setUp()
- self.loop = asyncio.new_event_loop()
- self.set_event_loop(self.loop)
- self.addCleanup(self.loop.close)
-
- def tearDown(self):
- # just in case if we have transport close callbacks
- if not self.loop.is_closed():
- test_utils.run_briefly(self.loop)
-
- self.doCleanups()
- support.gc_collect()
- super().tearDown()
-
- def tcp_server(self, server_prog, *,
- family=socket.AF_INET,
- addr=None,
- timeout=5,
- backlog=1,
- max_clients=10):
-
- if addr is None:
- if family == getattr(socket, "AF_UNIX", None):
- with tempfile.NamedTemporaryFile() as tmp:
- addr = tmp.name
- else:
- addr = ('127.0.0.1', 0)
-
- sock = socket.socket(family, socket.SOCK_STREAM)
-
- if timeout is None:
- raise RuntimeError('timeout is required')
- if timeout <= 0:
- raise RuntimeError('only blocking sockets are supported')
- sock.settimeout(timeout)
-
- try:
- sock.bind(addr)
- sock.listen(backlog)
- except OSError as ex:
- sock.close()
- raise ex
-
- return TestThreadedServer(
- self, sock, server_prog, timeout, max_clients)
-
- def tcp_client(self, client_prog,
- family=socket.AF_INET,
- timeout=10):
-
- sock = socket.socket(family, socket.SOCK_STREAM)
-
- if timeout is None:
- raise RuntimeError('timeout is required')
- if timeout <= 0:
- raise RuntimeError('only blocking sockets are supported')
- sock.settimeout(timeout)
-
- return TestThreadedClient(
- self, sock, client_prog, timeout)
-
- def unix_server(self, *args, **kwargs):
- return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
-
- def unix_client(self, *args, **kwargs):
- return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
-
- def _create_server_ssl_context(self, certfile, keyfile=None):
- sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
- sslcontext.options |= ssl.OP_NO_SSLv2
- sslcontext.load_cert_chain(certfile, keyfile)
- return sslcontext
-
- def _create_client_ssl_context(self, *, disable_verify=True):
- sslcontext = ssl.create_default_context()
- sslcontext.check_hostname = False
- if disable_verify:
- sslcontext.verify_mode = ssl.CERT_NONE
- return sslcontext
-
- @contextlib.contextmanager
- def _silence_eof_received_warning(self):
- # TODO This warning has to be fixed in asyncio.
- logger = logging.getLogger('asyncio')
- filter = logging.Filter('has no effect when using ssl')
- logger.addFilter(filter)
- try:
- yield
- finally:
- logger.removeFilter(filter)
-
- def _abort_socket_test(self, ex):
- try:
- self.loop.stop()
- finally:
- self.fail(ex)
-
- def new_loop(self):
- return asyncio.new_event_loop()
-
- def new_policy(self):
- return asyncio.DefaultEventLoopPolicy()
-
- async def wait_closed(self, obj):
- if not isinstance(obj, asyncio.StreamWriter):
- return
- try:
- await obj.wait_closed()
- except (BrokenPipeError, ConnectionError):
- pass
-
- def test_create_server_ssl_1(self):
- CNT = 0 # number of clients that were successful
- TOTAL_CNT = 25 # total number of clients that test will create
- TIMEOUT = 60.0 # timeout for this test
-
- A_DATA = b'A' * 1024 * 1024
- B_DATA = b'B' * 1024 * 1024
-
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT, test_utils.ONLYKEY
- )
- client_sslctx = self._create_client_ssl_context()
-
- clients = []
-
- async def handle_client(reader, writer):
- nonlocal CNT
-
- data = await reader.readexactly(len(A_DATA))
- self.assertEqual(data, A_DATA)
- writer.write(b'OK')
-
- data = await reader.readexactly(len(B_DATA))
- self.assertEqual(data, B_DATA)
- writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
-
- await writer.drain()
- writer.close()
-
- CNT += 1
-
- async def test_client(addr):
- fut = asyncio.Future()
-
- def prog(sock):
- try:
- sock.starttls(client_sslctx)
- sock.connect(addr)
- sock.send(A_DATA)
-
- data = sock.recv_all(2)
- self.assertEqual(data, b'OK')
-
- sock.send(B_DATA)
- data = sock.recv_all(4)
- self.assertEqual(data, b'SPAM')
-
- sock.close()
-
- except Exception as ex:
- self.loop.call_soon_threadsafe(fut.set_exception, ex)
- else:
- self.loop.call_soon_threadsafe(fut.set_result, None)
-
- client = self.tcp_client(prog)
- client.start()
- clients.append(client)
-
- await fut
-
- async def start_server():
- extras = {}
- extras = dict(ssl_handshake_timeout=40.0)
-
- srv = await asyncio.start_server(
- handle_client,
- '127.0.0.1', 0,
- family=socket.AF_INET,
- ssl=sslctx,
- **extras)
-
- try:
- srv_socks = srv.sockets
- self.assertTrue(srv_socks)
-
- addr = srv_socks[0].getsockname()
-
- tasks = []
- for _ in range(TOTAL_CNT):
- tasks.append(test_client(addr))
-
- await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
-
- finally:
- self.loop.call_soon(srv.close)
- await srv.wait_closed()
-
- with self._silence_eof_received_warning():
- self.loop.run_until_complete(start_server())
-
- self.assertEqual(CNT, TOTAL_CNT)
-
- for client in clients:
- client.stop()
-
- def test_create_connection_ssl_1(self):
- self.loop.set_exception_handler(None)
-
- CNT = 0
- TOTAL_CNT = 25
-
- A_DATA = b'A' * 1024 * 1024
- B_DATA = b'B' * 1024 * 1024
-
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT,
- test_utils.ONLYKEY
- )
- client_sslctx = self._create_client_ssl_context()
-
- def server(sock):
- sock.starttls(
- sslctx,
- server_side=True)
-
- data = sock.recv_all(len(A_DATA))
- self.assertEqual(data, A_DATA)
- sock.send(b'OK')
-
- data = sock.recv_all(len(B_DATA))
- self.assertEqual(data, B_DATA)
- sock.send(b'SPAM')
-
- sock.close()
-
- async def client(addr):
- extras = {}
- extras = dict(ssl_handshake_timeout=40.0)
-
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- **extras)
-
- writer.write(A_DATA)
- self.assertEqual(await reader.readexactly(2), b'OK')
-
- writer.write(B_DATA)
- self.assertEqual(await reader.readexactly(4), b'SPAM')
-
- nonlocal CNT
- CNT += 1
-
- writer.close()
- await self.wait_closed(writer)
-
- async def client_sock(addr):
- sock = socket.socket()
- sock.connect(addr)
- reader, writer = await asyncio.open_connection(
- sock=sock,
- ssl=client_sslctx,
- server_hostname='')
-
- writer.write(A_DATA)
- self.assertEqual(await reader.readexactly(2), b'OK')
-
- writer.write(B_DATA)
- self.assertEqual(await reader.readexactly(4), b'SPAM')
-
- nonlocal CNT
- CNT += 1
-
- writer.close()
- await self.wait_closed(writer)
- sock.close()
-
- def run(coro):
- nonlocal CNT
- CNT = 0
-
- async def _gather(*tasks):
- # trampoline
- return await asyncio.gather(*tasks)
-
- with self.tcp_server(server,
- max_clients=TOTAL_CNT,
- backlog=TOTAL_CNT) as srv:
- tasks = []
- for _ in range(TOTAL_CNT):
- tasks.append(coro(srv.addr))
-
- self.loop.run_until_complete(_gather(*tasks))
-
- self.assertEqual(CNT, TOTAL_CNT)
-
- with self._silence_eof_received_warning():
- run(client)
-
- with self._silence_eof_received_warning():
- run(client_sock)
-
- def test_create_connection_ssl_slow_handshake(self):
- client_sslctx = self._create_client_ssl_context()
-
- # silence error logger
- self.loop.set_exception_handler(lambda *args: None)
-
- def server(sock):
- try:
- sock.recv_all(1024 * 1024)
- except ConnectionAbortedError:
- pass
- finally:
- sock.close()
-
- async def client(addr):
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- ssl_handshake_timeout=1.0)
- writer.close()
- await self.wait_closed(writer)
-
- with self.tcp_server(server,
- max_clients=1,
- backlog=1) as srv:
-
- with self.assertRaisesRegex(
- ConnectionAbortedError,
- r'SSL handshake.*is taking longer'):
-
- self.loop.run_until_complete(client(srv.addr))
-
- def test_create_connection_ssl_failed_certificate(self):
- # silence error logger
- self.loop.set_exception_handler(lambda *args: None)
-
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT,
- test_utils.ONLYKEY
- )
- client_sslctx = self._create_client_ssl_context(disable_verify=False)
-
- def server(sock):
- try:
- sock.starttls(
- sslctx,
- server_side=True)
- sock.connect()
- except (ssl.SSLError, OSError):
- pass
- finally:
- sock.close()
-
- async def client(addr):
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- ssl_handshake_timeout=1.0)
- writer.close()
- await self.wait_closed(writer)
-
- with self.tcp_server(server,
- max_clients=1,
- backlog=1) as srv:
-
- with self.assertRaises(ssl.SSLCertVerificationError):
- self.loop.run_until_complete(client(srv.addr))
-
- def test_ssl_handshake_timeout(self):
- # bpo-29970: Check that a connection is aborted if handshake is not
- # completed in timeout period, instead of remaining open indefinitely
- client_sslctx = test_utils.simple_client_sslcontext()
-
- # silence error logger
- messages = []
- self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
-
- server_side_aborted = False
-
- def server(sock):
- nonlocal server_side_aborted
- try:
- sock.recv_all(1024 * 1024)
- except ConnectionAbortedError:
- server_side_aborted = True
- finally:
- sock.close()
-
- async def client(addr):
- await asyncio.wait_for(
- self.loop.create_connection(
- asyncio.Protocol,
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- ssl_handshake_timeout=10.0),
- 0.5)
-
- with self.tcp_server(server,
- max_clients=1,
- backlog=1) as srv:
-
- with self.assertRaises(asyncio.TimeoutError):
- self.loop.run_until_complete(client(srv.addr))
-
- self.assertTrue(server_side_aborted)
-
- # Python issue #23197: cancelling a handshake must not raise an
- # exception or log an error, even if the handshake failed
- self.assertEqual(messages, [])
-
- def test_ssl_handshake_connection_lost(self):
- # #246: make sure that no connection_lost() is called before
- # connection_made() is called first
-
- client_sslctx = test_utils.simple_client_sslcontext()
-
- # silence error logger
- self.loop.set_exception_handler(lambda loop, ctx: None)
-
- connection_made_called = False
- connection_lost_called = False
-
- def server(sock):
- sock.recv(1024)
- # break the connection during handshake
- sock.close()
-
- class ClientProto(asyncio.Protocol):
- def connection_made(self, transport):
- nonlocal connection_made_called
- connection_made_called = True
-
- def connection_lost(self, exc):
- nonlocal connection_lost_called
- connection_lost_called = True
-
- async def client(addr):
- await self.loop.create_connection(
- ClientProto,
- *addr,
- ssl=client_sslctx,
- server_hostname=''),
-
- with self.tcp_server(server,
- max_clients=1,
- backlog=1) as srv:
-
- with self.assertRaises(ConnectionResetError):
- self.loop.run_until_complete(client(srv.addr))
-
- if connection_lost_called:
- if connection_made_called:
- self.fail("unexpected call to connection_lost()")
- else:
- self.fail("unexpected call to connection_lost() without"
- "calling connection_made()")
- elif connection_made_called:
- self.fail("unexpected call to connection_made()")
-
- def test_ssl_connect_accepted_socket(self):
- proto = ssl.PROTOCOL_TLS_SERVER
- server_context = ssl.SSLContext(proto)
- server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY)
- if hasattr(server_context, 'check_hostname'):
- server_context.check_hostname = False
- server_context.verify_mode = ssl.CERT_NONE
-
- client_context = ssl.SSLContext(proto)
- if hasattr(server_context, 'check_hostname'):
- client_context.check_hostname = False
- client_context.verify_mode = ssl.CERT_NONE
-
- def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
- loop = self.loop
-
- class MyProto(MyBaseProto):
-
- def connection_lost(self, exc):
- super().connection_lost(exc)
- loop.call_soon(loop.stop)
-
- def data_received(self, data):
- super().data_received(data)
- self.transport.write(expected_response)
-
- lsock = socket.socket(socket.AF_INET)
- lsock.bind(('127.0.0.1', 0))
- lsock.listen(1)
- addr = lsock.getsockname()
-
- message = b'test data'
- response = None
- expected_response = b'roger'
-
- def client():
- nonlocal response
- try:
- csock = socket.socket(socket.AF_INET)
- if client_ssl is not None:
- csock = client_ssl.wrap_socket(csock)
- csock.connect(addr)
- csock.sendall(message)
- response = csock.recv(99)
- csock.close()
- except Exception as exc:
- print(
- "Failure in client thread in test_connect_accepted_socket",
- exc)
-
- thread = threading.Thread(target=client, daemon=True)
- thread.start()
-
- conn, _ = lsock.accept()
- proto = MyProto(loop=loop)
- proto.loop = loop
-
- extras = {}
- if server_ssl:
- extras = dict(ssl_handshake_timeout=10.0)
-
- f = loop.create_task(
- loop.connect_accepted_socket(
- (lambda: proto), conn, ssl=server_ssl,
- **extras))
- loop.run_forever()
- conn.close()
- lsock.close()
-
- thread.join(1)
- self.assertFalse(thread.is_alive())
- self.assertEqual(proto.state, 'CLOSED')
- self.assertEqual(proto.nbytes, len(message))
- self.assertEqual(response, expected_response)
- tr, _ = f.result()
-
- if server_ssl:
- self.assertIn('SSL', tr.__class__.__name__)
-
- tr.close()
- # let it close
- self.loop.run_until_complete(asyncio.sleep(0.1))
-
- def test_start_tls_client_corrupted_ssl(self):
- self.loop.set_exception_handler(lambda loop, ctx: None)
-
- sslctx = test_utils.simple_server_sslcontext()
- client_sslctx = test_utils.simple_client_sslcontext()
-
- def server(sock):
- orig_sock = sock.dup()
- try:
- sock.starttls(
- sslctx,
- server_side=True)
- sock.sendall(b'A\n')
- sock.recv_all(1)
- orig_sock.send(b'please corrupt the SSL connection')
- except ssl.SSLError:
- pass
- finally:
- sock.close()
- orig_sock.close()
-
- async def client(addr):
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='')
-
- self.assertEqual(await reader.readline(), b'A\n')
- writer.write(b'B')
- with self.assertRaises(ssl.SSLError):
- await reader.readline()
- writer.close()
- try:
- await self.wait_closed(writer)
- except ssl.SSLError:
- pass
- return 'OK'
-
- with self.tcp_server(server,
- max_clients=1,
- backlog=1) as srv:
-
- res = self.loop.run_until_complete(client(srv.addr))
-
- self.assertEqual(res, 'OK')
-
- def test_start_tls_client_reg_proto_1(self):
- HELLO_MSG = b'1' * self.PAYLOAD_SIZE
-
- server_context = test_utils.simple_server_sslcontext()
- client_context = test_utils.simple_client_sslcontext()
-
- def serve(sock):
- sock.settimeout(self.TIMEOUT)
-
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.starttls(server_context, server_side=True)
-
- sock.sendall(b'O')
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.unwrap()
- sock.close()
-
- class ClientProto(asyncio.Protocol):
- def __init__(self, on_data, on_eof):
- self.on_data = on_data
- self.on_eof = on_eof
- self.con_made_cnt = 0
-
- def connection_made(proto, tr):
- proto.con_made_cnt += 1
- # Ensure connection_made gets called only once.
- self.assertEqual(proto.con_made_cnt, 1)
-
- def data_received(self, data):
- self.on_data.set_result(data)
-
- def eof_received(self):
- self.on_eof.set_result(True)
-
- async def client(addr):
- await asyncio.sleep(0.5)
-
- on_data = self.loop.create_future()
- on_eof = self.loop.create_future()
-
- tr, proto = await self.loop.create_connection(
- lambda: ClientProto(on_data, on_eof), *addr)
-
- tr.write(HELLO_MSG)
- new_tr = await self.loop.start_tls(tr, proto, client_context)
-
- self.assertEqual(await on_data, b'O')
- new_tr.write(HELLO_MSG)
- await on_eof
-
- new_tr.close()
-
- with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
- self.loop.run_until_complete(
- asyncio.wait_for(client(srv.addr), timeout=10))
-
- def test_create_connection_memory_leak(self):
- HELLO_MSG = b'1' * self.PAYLOAD_SIZE
-
- server_context = self._create_server_ssl_context(
- test_utils.ONLYCERT, test_utils.ONLYKEY)
- client_context = self._create_client_ssl_context()
-
- def serve(sock):
- sock.settimeout(self.TIMEOUT)
-
- sock.starttls(server_context, server_side=True)
-
- sock.sendall(b'O')
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.unwrap()
- sock.close()
-
- class ClientProto(asyncio.Protocol):
- def __init__(self, on_data, on_eof):
- self.on_data = on_data
- self.on_eof = on_eof
- self.con_made_cnt = 0
-
- def connection_made(proto, tr):
- # XXX: We assume user stores the transport in protocol
- proto.tr = tr
- proto.con_made_cnt += 1
- # Ensure connection_made gets called only once.
- self.assertEqual(proto.con_made_cnt, 1)
-
- def data_received(self, data):
- self.on_data.set_result(data)
-
- def eof_received(self):
- self.on_eof.set_result(True)
-
- async def client(addr):
- await asyncio.sleep(0.5)
-
- on_data = self.loop.create_future()
- on_eof = self.loop.create_future()
-
- tr, proto = await self.loop.create_connection(
- lambda: ClientProto(on_data, on_eof), *addr,
- ssl=client_context)
-
- self.assertEqual(await on_data, b'O')
- tr.write(HELLO_MSG)
- await on_eof
-
- tr.close()
-
- with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
- self.loop.run_until_complete(
- asyncio.wait_for(client(srv.addr), timeout=10))
-
- # No garbage is left for SSL client from loop.create_connection, even
- # if user stores the SSLTransport in corresponding protocol instance
- client_context = weakref.ref(client_context)
- self.assertIsNone(client_context())
-
- def test_start_tls_client_buf_proto_1(self):
- HELLO_MSG = b'1' * self.PAYLOAD_SIZE
-
- server_context = test_utils.simple_server_sslcontext()
- client_context = test_utils.simple_client_sslcontext()
-
- client_con_made_calls = 0
-
- def serve(sock):
- sock.settimeout(self.TIMEOUT)
-
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.starttls(server_context, server_side=True)
-
- sock.sendall(b'O')
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.sendall(b'2')
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.unwrap()
- sock.close()
-
- class ClientProtoFirst(asyncio.BufferedProtocol):
- def __init__(self, on_data):
- self.on_data = on_data
- self.buf = bytearray(1)
-
- def connection_made(self, tr):
- nonlocal client_con_made_calls
- client_con_made_calls += 1
-
- def get_buffer(self, sizehint):
- return self.buf
-
- def buffer_updated(self, nsize):
- assert nsize == 1
- self.on_data.set_result(bytes(self.buf[:nsize]))
-
- def eof_received(self):
- pass
-
- class ClientProtoSecond(asyncio.Protocol):
- def __init__(self, on_data, on_eof):
- self.on_data = on_data
- self.on_eof = on_eof
- self.con_made_cnt = 0
-
- def connection_made(self, tr):
- nonlocal client_con_made_calls
- client_con_made_calls += 1
-
- def data_received(self, data):
- self.on_data.set_result(data)
-
- def eof_received(self):
- self.on_eof.set_result(True)
-
- async def client(addr):
- await asyncio.sleep(0.5)
-
- on_data1 = self.loop.create_future()
- on_data2 = self.loop.create_future()
- on_eof = self.loop.create_future()
-
- tr, proto = await self.loop.create_connection(
- lambda: ClientProtoFirst(on_data1), *addr)
-
- tr.write(HELLO_MSG)
- new_tr = await self.loop.start_tls(tr, proto, client_context)
-
- self.assertEqual(await on_data1, b'O')
- new_tr.write(HELLO_MSG)
-
- new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
- self.assertEqual(await on_data2, b'2')
- new_tr.write(HELLO_MSG)
- await on_eof
-
- new_tr.close()
-
- # connection_made() should be called only once -- when
- # we establish connection for the first time. Start TLS
- # doesn't call connection_made() on application protocols.
- self.assertEqual(client_con_made_calls, 1)
-
- with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
- self.loop.run_until_complete(
- asyncio.wait_for(client(srv.addr),
- timeout=self.TIMEOUT))
-
- def test_start_tls_slow_client_cancel(self):
- HELLO_MSG = b'1' * self.PAYLOAD_SIZE
-
- client_context = test_utils.simple_client_sslcontext()
- server_waits_on_handshake = self.loop.create_future()
-
- def serve(sock):
- sock.settimeout(self.TIMEOUT)
-
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- try:
- self.loop.call_soon_threadsafe(
- server_waits_on_handshake.set_result, None)
- data = sock.recv_all(1024 * 1024)
- except ConnectionAbortedError:
- pass
- finally:
- sock.close()
-
- class ClientProto(asyncio.Protocol):
- def __init__(self, on_data, on_eof):
- self.on_data = on_data
- self.on_eof = on_eof
- self.con_made_cnt = 0
-
- def connection_made(proto, tr):
- proto.con_made_cnt += 1
- # Ensure connection_made gets called only once.
- self.assertEqual(proto.con_made_cnt, 1)
-
- def data_received(self, data):
- self.on_data.set_result(data)
-
- def eof_received(self):
- self.on_eof.set_result(True)
-
- async def client(addr):
- await asyncio.sleep(0.5)
-
- on_data = self.loop.create_future()
- on_eof = self.loop.create_future()
-
- tr, proto = await self.loop.create_connection(
- lambda: ClientProto(on_data, on_eof), *addr)
-
- tr.write(HELLO_MSG)
-
- await server_waits_on_handshake
-
- with self.assertRaises(asyncio.TimeoutError):
- await asyncio.wait_for(
- self.loop.start_tls(tr, proto, client_context),
- 0.5)
-
- with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
- self.loop.run_until_complete(
- asyncio.wait_for(client(srv.addr), timeout=10))
-
- def test_start_tls_server_1(self):
- HELLO_MSG = b'1' * self.PAYLOAD_SIZE
-
- server_context = test_utils.simple_server_sslcontext()
- client_context = test_utils.simple_client_sslcontext()
-
- def client(sock, addr):
- sock.settimeout(self.TIMEOUT)
-
- sock.connect(addr)
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.starttls(client_context)
- sock.sendall(HELLO_MSG)
-
- sock.unwrap()
- sock.close()
-
- class ServerProto(asyncio.Protocol):
- def __init__(self, on_con, on_eof, on_con_lost):
- self.on_con = on_con
- self.on_eof = on_eof
- self.on_con_lost = on_con_lost
- self.data = b''
-
- def connection_made(self, tr):
- self.on_con.set_result(tr)
-
- def data_received(self, data):
- self.data += data
-
- def eof_received(self):
- self.on_eof.set_result(1)
-
- def connection_lost(self, exc):
- if exc is None:
- self.on_con_lost.set_result(None)
- else:
- self.on_con_lost.set_exception(exc)
-
- async def main(proto, on_con, on_eof, on_con_lost):
- tr = await on_con
- tr.write(HELLO_MSG)
-
- self.assertEqual(proto.data, b'')
-
- new_tr = await self.loop.start_tls(
- tr, proto, server_context,
- server_side=True,
- ssl_handshake_timeout=self.TIMEOUT)
-
- await on_eof
- await on_con_lost
- self.assertEqual(proto.data, HELLO_MSG)
- new_tr.close()
-
- async def run_main():
- on_con = self.loop.create_future()
- on_eof = self.loop.create_future()
- on_con_lost = self.loop.create_future()
- proto = ServerProto(on_con, on_eof, on_con_lost)
-
- server = await self.loop.create_server(
- lambda: proto, '127.0.0.1', 0)
- addr = server.sockets[0].getsockname()
-
- with self.tcp_client(lambda sock: client(sock, addr),
- timeout=self.TIMEOUT):
- await asyncio.wait_for(
- main(proto, on_con, on_eof, on_con_lost),
- timeout=self.TIMEOUT)
-
- server.close()
- await server.wait_closed()
-
- self.loop.run_until_complete(run_main())
-
- def test_create_server_ssl_over_ssl(self):
- CNT = 0 # number of clients that were successful
- TOTAL_CNT = 25 # total number of clients that test will create
- TIMEOUT = 10.0 # timeout for this test
-
- A_DATA = b'A' * 1024 * 1024
- B_DATA = b'B' * 1024 * 1024
-
- sslctx_1 = self._create_server_ssl_context(
- test_utils.ONLYCERT, test_utils.ONLYKEY)
- client_sslctx_1 = self._create_client_ssl_context()
- sslctx_2 = self._create_server_ssl_context(
- test_utils.ONLYCERT, test_utils.ONLYKEY)
- client_sslctx_2 = self._create_client_ssl_context()
-
- clients = []
-
- async def handle_client(reader, writer):
- nonlocal CNT
-
- data = await reader.readexactly(len(A_DATA))
- self.assertEqual(data, A_DATA)
- writer.write(b'OK')
-
- data = await reader.readexactly(len(B_DATA))
- self.assertEqual(data, B_DATA)
- writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
-
- await writer.drain()
- writer.close()
-
- CNT += 1
-
- class ServerProtocol(asyncio.StreamReaderProtocol):
- def connection_made(self, transport):
- super_ = super()
- transport.pause_reading()
- fut = self._loop.create_task(self._loop.start_tls(
- transport, self, sslctx_2, server_side=True))
-
- def cb(_):
- try:
- tr = fut.result()
- except Exception as ex:
- super_.connection_lost(ex)
- else:
- super_.connection_made(tr)
- fut.add_done_callback(cb)
-
- def server_protocol_factory():
- reader = asyncio.StreamReader()
- protocol = ServerProtocol(reader, handle_client)
- return protocol
-
- async def test_client(addr):
- fut = asyncio.Future()
-
- def prog(sock):
- try:
- sock.connect(addr)
- sock.starttls(client_sslctx_1)
-
- # because wrap_socket() doesn't work correctly on
- # SSLSocket, we have to do the 2nd level SSL manually
- incoming = ssl.MemoryBIO()
- outgoing = ssl.MemoryBIO()
- sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)
-
- def do(func, *args):
- while True:
- try:
- rv = func(*args)
- break
- except ssl.SSLWantReadError:
- if outgoing.pending:
- sock.send(outgoing.read())
- incoming.write(sock.recv(65536))
- if outgoing.pending:
- sock.send(outgoing.read())
- return rv
-
- do(sslobj.do_handshake)
-
- do(sslobj.write, A_DATA)
- data = do(sslobj.read, 2)
- self.assertEqual(data, b'OK')
-
- do(sslobj.write, B_DATA)
- data = b''
- while True:
- chunk = do(sslobj.read, 4)
- if not chunk:
- break
- data += chunk
- self.assertEqual(data, b'SPAM')
-
- do(sslobj.unwrap)
- sock.close()
-
- except Exception as ex:
- self.loop.call_soon_threadsafe(fut.set_exception, ex)
- sock.close()
- else:
- self.loop.call_soon_threadsafe(fut.set_result, None)
-
- client = self.tcp_client(prog)
- client.start()
- clients.append(client)
-
- await fut
-
- async def start_server():
- extras = {}
-
- srv = await self.loop.create_server(
- server_protocol_factory,
- '127.0.0.1', 0,
- family=socket.AF_INET,
- ssl=sslctx_1,
- **extras)
-
- try:
- srv_socks = srv.sockets
- self.assertTrue(srv_socks)
-
- addr = srv_socks[0].getsockname()
-
- tasks = []
- for _ in range(TOTAL_CNT):
- tasks.append(test_client(addr))
-
- await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
-
- finally:
- self.loop.call_soon(srv.close)
- await srv.wait_closed()
-
- with self._silence_eof_received_warning():
- self.loop.run_until_complete(start_server())
-
- self.assertEqual(CNT, TOTAL_CNT)
-
- for client in clients:
- client.stop()
-
- def test_shutdown_cleanly(self):
- CNT = 0
- TOTAL_CNT = 25
-
- A_DATA = b'A' * 1024 * 1024
-
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT, test_utils.ONLYKEY)
- client_sslctx = self._create_client_ssl_context()
-
- def server(sock):
- sock.starttls(
- sslctx,
- server_side=True)
-
- data = sock.recv_all(len(A_DATA))
- self.assertEqual(data, A_DATA)
- sock.send(b'OK')
-
- sock.unwrap()
-
- sock.close()
-
- async def client(addr):
- extras = {}
- extras = dict(ssl_handshake_timeout=10.0)
-
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- **extras)
-
- writer.write(A_DATA)
- self.assertEqual(await reader.readexactly(2), b'OK')
-
- self.assertEqual(await reader.read(), b'')
-
- nonlocal CNT
- CNT += 1
-
- writer.close()
- await self.wait_closed(writer)
-
- def run(coro):
- nonlocal CNT
- CNT = 0
-
- async def _gather(*tasks):
- return await asyncio.gather(*tasks)
-
- with self.tcp_server(server,
- max_clients=TOTAL_CNT,
- backlog=TOTAL_CNT) as srv:
- tasks = []
- for _ in range(TOTAL_CNT):
- tasks.append(coro(srv.addr))
-
- self.loop.run_until_complete(
- _gather(*tasks))
-
- self.assertEqual(CNT, TOTAL_CNT)
-
- with self._silence_eof_received_warning():
- run(client)
-
- def test_flush_before_shutdown(self):
- CHUNK = 1024 * 128
- SIZE = 32
-
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT, test_utils.ONLYKEY)
- client_sslctx = self._create_client_ssl_context()
- if hasattr(ssl, 'OP_NO_TLSv1_3'):
- client_sslctx.options |= ssl.OP_NO_TLSv1_3
-
- future = None
-
- def server(sock):
- sock.starttls(sslctx, server_side=True)
- self.assertEqual(sock.recv_all(4), b'ping')
- sock.send(b'pong')
- time.sleep(0.5) # hopefully stuck the TCP buffer
- data = sock.recv_all(CHUNK * SIZE)
- self.assertEqual(len(data), CHUNK * SIZE)
- sock.close()
-
- def run(meth):
- def wrapper(sock):
- try:
- meth(sock)
- except Exception as ex:
- self.loop.call_soon_threadsafe(future.set_exception, ex)
- else:
- self.loop.call_soon_threadsafe(future.set_result, None)
- return wrapper
-
- async def client(addr):
- nonlocal future
- future = self.loop.create_future()
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='')
- sslprotocol = writer.transport._ssl_protocol
- writer.write(b'ping')
- data = await reader.readexactly(4)
- self.assertEqual(data, b'pong')
-
- sslprotocol.pause_writing()
- for _ in range(SIZE):
- writer.write(b'x' * CHUNK)
-
- writer.close()
- sslprotocol.resume_writing()
-
- await self.wait_closed(writer)
- try:
- data = await reader.read()
- self.assertEqual(data, b'')
- except ConnectionResetError:
- pass
- await future
-
- with self.tcp_server(run(server)) as srv:
- self.loop.run_until_complete(client(srv.addr))
-
- def test_remote_shutdown_receives_trailing_data(self):
- CHUNK = 1024 * 128
- SIZE = 32
-
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT,
- test_utils.ONLYKEY
- )
- client_sslctx = self._create_client_ssl_context()
- future = None
-
- def server(sock):
- incoming = ssl.MemoryBIO()
- outgoing = ssl.MemoryBIO()
- sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)
-
- while True:
- try:
- sslobj.do_handshake()
- except ssl.SSLWantReadError:
- if outgoing.pending:
- sock.send(outgoing.read())
- incoming.write(sock.recv(16384))
- else:
- if outgoing.pending:
- sock.send(outgoing.read())
- break
-
- while True:
- try:
- data = sslobj.read(4)
- except ssl.SSLWantReadError:
- incoming.write(sock.recv(16384))
- else:
- break
-
- self.assertEqual(data, b'ping')
- sslobj.write(b'pong')
- sock.send(outgoing.read())
-
- time.sleep(0.2) # wait for the peer to fill its backlog
-
- # send close_notify but don't wait for response
- with self.assertRaises(ssl.SSLWantReadError):
- sslobj.unwrap()
- sock.send(outgoing.read())
-
- # should receive all data
- data_len = 0
- while True:
- try:
- chunk = len(sslobj.read(16384))
- data_len += chunk
- except ssl.SSLWantReadError:
- incoming.write(sock.recv(16384))
- except ssl.SSLZeroReturnError:
- break
-
- self.assertEqual(data_len, CHUNK * SIZE)
-
- # verify that close_notify is received
- sslobj.unwrap()
-
- sock.close()
-
- def eof_server(sock):
- sock.starttls(sslctx, server_side=True)
- self.assertEqual(sock.recv_all(4), b'ping')
- sock.send(b'pong')
-
- time.sleep(0.2) # wait for the peer to fill its backlog
-
- # send EOF
- sock.shutdown(socket.SHUT_WR)
-
- # should receive all data
- data = sock.recv_all(CHUNK * SIZE)
- self.assertEqual(len(data), CHUNK * SIZE)
-
- sock.close()
-
- async def client(addr):
- nonlocal future
- future = self.loop.create_future()
-
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='')
- writer.write(b'ping')
- data = await reader.readexactly(4)
- self.assertEqual(data, b'pong')
-
- # fill write backlog in a hacky way - renegotiation won't help
- for _ in range(SIZE):
- writer.transport._test__append_write_backlog(b'x' * CHUNK)
-
- try:
- data = await reader.read()
- self.assertEqual(data, b'')
- except (BrokenPipeError, ConnectionResetError):
- pass
-
- await future
-
- writer.close()
- await self.wait_closed(writer)
-
- def run(meth):
- def wrapper(sock):
- try:
- meth(sock)
- except Exception as ex:
- self.loop.call_soon_threadsafe(future.set_exception, ex)
- else:
- self.loop.call_soon_threadsafe(future.set_result, None)
- return wrapper
-
- with self.tcp_server(run(server)) as srv:
- self.loop.run_until_complete(client(srv.addr))
-
- with self.tcp_server(run(eof_server)) as srv:
- self.loop.run_until_complete(client(srv.addr))
-
- def test_connect_timeout_warning(self):
- s = socket.socket(socket.AF_INET)
- s.bind(('127.0.0.1', 0))
- addr = s.getsockname()
-
- async def test():
- try:
- await asyncio.wait_for(
- self.loop.create_connection(asyncio.Protocol,
- *addr, ssl=True),
- 0.1)
- except (ConnectionRefusedError, asyncio.TimeoutError):
- pass
- else:
- self.fail('TimeoutError is not raised')
-
- with s:
- try:
- with self.assertWarns(ResourceWarning) as cm:
- self.loop.run_until_complete(test())
- gc.collect()
- gc.collect()
- gc.collect()
- except AssertionError as e:
- self.assertEqual(str(e), 'ResourceWarning not triggered')
- else:
- self.fail('Unexpected ResourceWarning: {}'.format(cm.warning))
-
- def test_handshake_timeout_handler_leak(self):
- s = socket.socket(socket.AF_INET)
- s.bind(('127.0.0.1', 0))
- s.listen(1)
- addr = s.getsockname()
-
- async def test(ctx):
- try:
- await asyncio.wait_for(
- self.loop.create_connection(asyncio.Protocol, *addr,
- ssl=ctx),
- 0.1)
- except (ConnectionRefusedError, asyncio.TimeoutError):
- pass
- else:
- self.fail('TimeoutError is not raised')
-
- with s:
- ctx = ssl.create_default_context()
- self.loop.run_until_complete(test(ctx))
- ctx = weakref.ref(ctx)
-
- # SSLProtocol should be DECREF to 0
- self.assertIsNone(ctx())
-
- def test_shutdown_timeout_handler_leak(self):
- loop = self.loop
-
- def server(sock):
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT,
- test_utils.ONLYKEY
- )
- sock = sslctx.wrap_socket(sock, server_side=True)
- sock.recv(32)
- sock.close()
-
- class Protocol(asyncio.Protocol):
- def __init__(self):
- self.fut = asyncio.Future(loop=loop)
-
- def connection_lost(self, exc):
- self.fut.set_result(None)
-
- async def client(addr, ctx):
- tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
- tr.close()
- await pr.fut
-
- with self.tcp_server(server) as srv:
- ctx = self._create_client_ssl_context()
- loop.run_until_complete(client(srv.addr, ctx))
- ctx = weakref.ref(ctx)
-
- # asyncio has no shutdown timeout, but it ends up with a circular
- # reference loop - not ideal (introduces gc glitches), but at least
- # not leaking
- gc.collect()
- gc.collect()
- gc.collect()
-
- # SSLProtocol should be DECREF to 0
- self.assertIsNone(ctx())
-
- def test_shutdown_timeout_handler_not_set(self):
- loop = self.loop
- eof = asyncio.Event()
- extra = None
-
- def server(sock):
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT,
- test_utils.ONLYKEY
- )
- sock = sslctx.wrap_socket(sock, server_side=True)
- sock.send(b'hello')
- assert sock.recv(1024) == b'world'
- sock.send(b'extra bytes')
- # sending EOF here
- sock.shutdown(socket.SHUT_WR)
- loop.call_soon_threadsafe(eof.set)
- # make sure we have enough time to reproduce the issue
- assert sock.recv(1024) == b''
- sock.close()
-
- class Protocol(asyncio.Protocol):
- def __init__(self):
- self.fut = asyncio.Future(loop=loop)
- self.transport = None
-
- def connection_made(self, transport):
- self.transport = transport
-
- def data_received(self, data):
- if data == b'hello':
- self.transport.write(b'world')
- # pause reading would make incoming data stay in the sslobj
- self.transport.pause_reading()
- else:
- nonlocal extra
- extra = data
-
- def connection_lost(self, exc):
- if exc is None:
- self.fut.set_result(None)
- else:
- self.fut.set_exception(exc)
-
- async def client(addr):
- ctx = self._create_client_ssl_context()
- tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
- await eof.wait()
- tr.resume_reading()
- await pr.fut
- tr.close()
- assert extra == b'extra bytes'
-
- with self.tcp_server(server) as srv:
- loop.run_until_complete(client(srv.addr))
-
-
-###############################################################################
-# Socket Testing Utilities
-###############################################################################
-
-
-class TestSocketWrapper:
-
- def __init__(self, sock):
- self.__sock = sock
-
- def recv_all(self, n):
- buf = b''
- while len(buf) < n:
- data = self.recv(n - len(buf))
- if data == b'':
- raise ConnectionAbortedError
- buf += data
- return buf
-
- def starttls(self, ssl_context, *,
- server_side=False,
- server_hostname=None,
- do_handshake_on_connect=True):
-
- assert isinstance(ssl_context, ssl.SSLContext)
-
- ssl_sock = ssl_context.wrap_socket(
- self.__sock, server_side=server_side,
- server_hostname=server_hostname,
- do_handshake_on_connect=do_handshake_on_connect)
-
- if server_side:
- ssl_sock.do_handshake()
-
- self.__sock.close()
- self.__sock = ssl_sock
-
- def __getattr__(self, name):
- return getattr(self.__sock, name)
-
- def __repr__(self):
- return '<{} {!r}>'.format(type(self).__name__, self.__sock)
-
-
-class SocketThread(threading.Thread):
-
- def stop(self):
- self._active = False
- self.join()
-
- def __enter__(self):
- self.start()
- return self
-
- def __exit__(self, *exc):
- self.stop()
-
-
-class TestThreadedClient(SocketThread):
-
- def __init__(self, test, sock, prog, timeout):
- threading.Thread.__init__(self, None, None, 'test-client')
- self.daemon = True
-
- self._timeout = timeout
- self._sock = sock
- self._active = True
- self._prog = prog
- self._test = test
-
- def run(self):
- try:
- self._prog(TestSocketWrapper(self._sock))
- except (KeyboardInterrupt, SystemExit):
- raise
- except BaseException as ex:
- self._test._abort_socket_test(ex)
-
-
-class TestThreadedServer(SocketThread):
-
- def __init__(self, test, sock, prog, timeout, max_clients):
- threading.Thread.__init__(self, None, None, 'test-server')
- self.daemon = True
-
- self._clients = 0
- self._finished_clients = 0
- self._max_clients = max_clients
- self._timeout = timeout
- self._sock = sock
- self._active = True
-
- self._prog = prog
-
- self._s1, self._s2 = socket.socketpair()
- self._s1.setblocking(False)
-
- self._test = test
-
- def stop(self):
- try:
- if self._s2 and self._s2.fileno() != -1:
- try:
- self._s2.send(b'stop')
- except OSError:
- pass
- finally:
- super().stop()
-
- def run(self):
- try:
- with self._sock:
- self._sock.setblocking(0)
- self._run()
- finally:
- self._s1.close()
- self._s2.close()
-
- def _run(self):
- while self._active:
- if self._clients >= self._max_clients:
- return
-
- r, w, x = select.select(
- [self._sock, self._s1], [], [], self._timeout)
-
- if self._s1 in r:
- return
-
- if self._sock in r:
- try:
- conn, addr = self._sock.accept()
- except BlockingIOError:
- continue
- except socket.timeout:
- if not self._active:
- return
- else:
- raise
- else:
- self._clients += 1
- conn.settimeout(self._timeout)
- try:
- with conn:
- self._handle_client(conn)
- except (KeyboardInterrupt, SystemExit):
- raise
- except BaseException as ex:
- self._active = False
- try:
- raise
- finally:
- self._test._abort_socket_test(ex)
-
- def _handle_client(self, sock):
- self._prog(TestSocketWrapper(sock))
-
- @property
- def addr(self):
- return self._sock.getsockname()
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py
index 79a81bd..e87863e 100644
--- a/Lib/test/test_asyncio/test_sslproto.py
+++ b/Lib/test/test_asyncio/test_sslproto.py
@@ -15,6 +15,7 @@ import asyncio
from asyncio import log
from asyncio import protocols
from asyncio import sslproto
+from test import support
from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests
@@ -43,13 +44,16 @@ class SslProtoHandshakeTests(test_utils.TestCase):
def connection_made(self, ssl_proto, *, do_handshake=None):
transport = mock.Mock()
- sslobj = mock.Mock()
- # emulate reading decompressed data
- sslobj.read.side_effect = ssl.SSLWantReadError
- if do_handshake is not None:
- sslobj.do_handshake = do_handshake
- ssl_proto._sslobj = sslobj
- ssl_proto.connection_made(transport)
+ sslpipe = mock.Mock()
+ sslpipe.shutdown.return_value = b''
+ if do_handshake:
+ sslpipe.do_handshake.side_effect = do_handshake
+ else:
+ def mock_handshake(callback):
+ return []
+ sslpipe.do_handshake.side_effect = mock_handshake
+ with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
+ ssl_proto.connection_made(transport)
return transport
def test_handshake_timeout_zero(self):
@@ -71,10 +75,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
def test_eof_received_waiter(self):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
- self.connection_made(
- ssl_proto,
- do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
- )
+ self.connection_made(ssl_proto)
ssl_proto.eof_received()
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionResetError)
@@ -99,10 +100,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
# yield from waiter hang if lost_connection was called.
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
- self.connection_made(
- ssl_proto,
- do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
- )
+ self.connection_made(ssl_proto)
ssl_proto.connection_lost(ConnectionAbortedError)
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
@@ -112,10 +110,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
- transport = self.connection_made(
- ssl_proto,
- do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
- )
+ transport = self.connection_made(ssl_proto)
test_utils.run_briefly(self.loop)
ssl_proto._app_transport.close()
@@ -148,7 +143,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
transp.close()
# should not raise
- self.assertIsNone(ssl_proto.buffer_updated(5))
+ self.assertIsNone(ssl_proto.data_received(b'data'))
def test_write_after_closing(self):
ssl_proto = self.ssl_protocol()
diff --git a/Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst b/Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst
deleted file mode 100644
index e2b5a9e..0000000
--- a/Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst
+++ /dev/null
@@ -1,2 +0,0 @@
-Reimplement SSL/TLS support in asyncio, borrow the impelementation from
-uvloop library.