summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio/streams.py
diff options
context:
space:
mode:
authorOleg Iarygin <oleg@arhadthedev.net>2022-04-15 12:23:14 (GMT)
committerGitHub <noreply@github.com>2022-04-15 12:23:14 (GMT)
commit6217864fe5f6855f59d608733ce83fd4466e1b8c (patch)
tree3d852fadd0e29891d382ed9f41f161b237b3e703 /Lib/asyncio/streams.py
parentbd26ef5e9e701d2ab3509a49d9351259a3670772 (diff)
downloadcpython-6217864fe5f6855f59d608733ce83fd4466e1b8c.zip
cpython-6217864fe5f6855f59d608733ce83fd4466e1b8c.tar.gz
cpython-6217864fe5f6855f59d608733ce83fd4466e1b8c.tar.bz2
gh-79156: Add start_tls() method to streams API (#91453)
The existing event loop `start_tls()` method is not sufficient for connections using the streams API. The existing StreamReader works because the new transport passes received data to the original protocol. The StreamWriter must then write data to the new transport, and the StreamReaderProtocol must be updated to close the new transport correctly. The new StreamWriter `start_tls()` updates itself and the reader protocol to the new SSL transport. Co-authored-by: Ian Good <icgood@gmail.com>
Diffstat (limited to 'Lib/asyncio/streams.py')
-rw-r--r--Lib/asyncio/streams.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 080d8a6..a568c4e 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -217,6 +217,13 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
return None
return self._stream_reader_wr()
+ def _replace_writer(self, writer):
+ loop = self._loop
+ transport = writer.transport
+ self._stream_writer = writer
+ self._transport = transport
+ self._over_ssl = transport.get_extra_info('sslcontext') is not None
+
def connection_made(self, transport):
if self._reject_connection:
context = {
@@ -371,6 +378,20 @@ class StreamWriter:
await sleep(0)
await self._protocol._drain_helper()
+ async def start_tls(self, sslcontext, *,
+ server_hostname=None,
+ ssl_handshake_timeout=None):
+ """Upgrade an existing stream-based connection to TLS."""
+ server_side = self._protocol._client_connected_cb is not None
+ protocol = self._protocol
+ await self.drain()
+ new_transport = await self._loop.start_tls( # type: ignore
+ self._transport, protocol, sslcontext,
+ server_side=server_side, server_hostname=server_hostname,
+ ssl_handshake_timeout=ssl_handshake_timeout)
+ self._transport = new_transport
+ protocol._replace_writer(self)
+
class StreamReader: