diff options
author | Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com> | 2022-02-15 13:04:00 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-15 13:04:00 (GMT) |
commit | 13c10bfb777483c7b02877aab029345a056b809c (patch) | |
tree | 4a94952a81baef1c7ceef4edc5f5d5cc6e33e2e9 /Lib/asyncio/base_events.py | |
parent | 3be1a443ca8e7d4ba85f95b78df5c4122cae9ede (diff) | |
download | cpython-13c10bfb777483c7b02877aab029345a056b809c.zip cpython-13c10bfb777483c7b02877aab029345a056b809c.tar.gz cpython-13c10bfb777483c7b02877aab029345a056b809c.tar.bz2 |
bpo-44011: New asyncio ssl implementation (#31275)
* bpo-44011: New asyncio ssl implementation
Co-Authored-By: Andrew Svetlov <andrew.svetlov@gmail.com>
* fix warning
* fix typo
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
Diffstat (limited to 'Lib/asyncio/base_events.py')
-rw-r--r-- | Lib/asyncio/base_events.py | 43 |
1 files changed, 34 insertions, 9 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 56ea7ba..703c8a4 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -269,7 +269,7 @@ class _SendfileFallbackProtocol(protocols.Protocol): class Server(events.AbstractServer): def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, - ssl_handshake_timeout): + ssl_handshake_timeout, ssl_shutdown_timeout=None): self._loop = loop self._sockets = sockets self._active_count = 0 @@ -278,6 +278,7 @@ class Server(events.AbstractServer): self._backlog = backlog self._ssl_context = ssl_context self._ssl_handshake_timeout = ssl_handshake_timeout + self._ssl_shutdown_timeout = ssl_shutdown_timeout self._serving = False self._serving_forever_fut = None @@ -309,7 +310,8 @@ class Server(events.AbstractServer): sock.listen(self._backlog) self._loop._start_serving( self._protocol_factory, sock, self._ssl_context, - self, self._backlog, self._ssl_handshake_timeout) + self, self._backlog, self._ssl_handshake_timeout, + self._ssl_shutdown_timeout) def get_loop(self): return self._loop @@ -463,6 +465,7 @@ class BaseEventLoop(events.AbstractEventLoop): *, server_side=False, server_hostname=None, extra=None, server=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, call_connection_made=True): """Create SSL transport.""" raise NotImplementedError @@ -965,6 +968,7 @@ class BaseEventLoop(events.AbstractEventLoop): proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, happy_eyeballs_delay=None, interleave=None): """Connect to a TCP server. @@ -1000,6 +1004,10 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and not ssl: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + if happy_eyeballs_delay is not None and interleave is None: # If using happy eyeballs, default to interleave addresses by family interleave = 1 @@ -1075,7 +1083,8 @@ class BaseEventLoop(events.AbstractEventLoop): transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket @@ -1087,7 +1096,8 @@ class BaseEventLoop(events.AbstractEventLoop): async def _create_connection_transport( self, sock, protocol_factory, ssl, server_hostname, server_side=False, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): sock.setblocking(False) @@ -1098,7 +1108,8 @@ class BaseEventLoop(events.AbstractEventLoop): transport = self._make_ssl_transport( sock, protocol, sslcontext, waiter, server_side=server_side, server_hostname=server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) else: transport = self._make_socket_transport(sock, protocol, waiter) @@ -1189,7 +1200,8 @@ class BaseEventLoop(events.AbstractEventLoop): async def start_tls(self, transport, protocol, sslcontext, *, server_side=False, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): """Upgrade transport to TLS. Return a new transport that *protocol* should start using @@ -1212,6 +1224,7 @@ class BaseEventLoop(events.AbstractEventLoop): self, protocol, sslcontext, waiter, server_side, server_hostname, ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout, call_connection_made=False) # Pause early so that "ssl_protocol.data_received()" doesn't @@ -1397,6 +1410,7 @@ class BaseEventLoop(events.AbstractEventLoop): reuse_address=None, reuse_port=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, start_serving=True): """Create a TCP server. @@ -1420,6 +1434,10 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and ssl is None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + if host is not None or port is not None: if sock is not None: raise ValueError( @@ -1492,7 +1510,8 @@ class BaseEventLoop(events.AbstractEventLoop): sock.setblocking(False) server = Server(self, sockets, protocol_factory, - ssl, backlog, ssl_handshake_timeout) + ssl, backlog, ssl_handshake_timeout, + ssl_shutdown_timeout) if start_serving: server._start_serving() # Skip one loop iteration so that all 'loop.add_reader' @@ -1506,7 +1525,8 @@ class BaseEventLoop(events.AbstractEventLoop): async def connect_accepted_socket( self, protocol_factory, sock, *, ssl=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): if sock.type != socket.SOCK_STREAM: raise ValueError(f'A Stream Socket was expected, got {sock!r}') @@ -1514,9 +1534,14 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and not ssl: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, '', server_side=True, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket |