summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Doc/library/asyncio-stream.rst18
-rw-r--r--Doc/whatsnew/3.11.rst4
-rw-r--r--Lib/asyncio/streams.py21
-rw-r--r--Lib/test/test_asyncio/test_streams.py63
-rw-r--r--Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst3
5 files changed, 109 insertions, 0 deletions
diff --git a/Doc/library/asyncio-stream.rst b/Doc/library/asyncio-stream.rst
index ba534f9..72355d3 100644
--- a/Doc/library/asyncio-stream.rst
+++ b/Doc/library/asyncio-stream.rst
@@ -295,6 +295,24 @@ StreamWriter
be resumed. When there is nothing to wait for, the :meth:`drain`
returns immediately.
+ .. coroutinemethod:: start_tls(sslcontext, \*, server_hostname=None, \
+ ssl_handshake_timeout=None)
+
+ Upgrade an existing stream-based connection to TLS.
+
+ Parameters:
+
+ * *sslcontext*: a configured instance of :class:`~ssl.SSLContext`.
+
+ * *server_hostname*: sets or overrides the host name that the target
+ server's certificate will be matched against.
+
+ * *ssl_handshake_timeout* is the time in seconds to wait for the TLS
+ handshake to complete before aborting the connection. ``60.0`` seconds
+ if ``None`` (default).
+
+ .. versionadded:: 3.8
+
.. method:: is_closing()
Return ``True`` if the stream is closed or in the process of
diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst
index dba554c..9f7f6f5 100644
--- a/Doc/whatsnew/3.11.rst
+++ b/Doc/whatsnew/3.11.rst
@@ -246,6 +246,10 @@ asyncio
:meth:`~asyncio.AbstractEventLoop.sock_recvfrom_into`.
(Contributed by Alex Grönholm in :issue:`46805`.)
+* Add :meth:`~asyncio.streams.StreamWriter.start_tls` method for upgrading
+ existing stream-based connections to TLS. (Contributed by Ian Good in
+ :issue:`34975`.)
+
fractions
---------
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 080d8a6..a568c4e 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -217,6 +217,13 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
return None
return self._stream_reader_wr()
+ def _replace_writer(self, writer):
+ loop = self._loop
+ transport = writer.transport
+ self._stream_writer = writer
+ self._transport = transport
+ self._over_ssl = transport.get_extra_info('sslcontext') is not None
+
def connection_made(self, transport):
if self._reject_connection:
context = {
@@ -371,6 +378,20 @@ class StreamWriter:
await sleep(0)
await self._protocol._drain_helper()
+ async def start_tls(self, sslcontext, *,
+ server_hostname=None,
+ ssl_handshake_timeout=None):
+ """Upgrade an existing stream-based connection to TLS."""
+ server_side = self._protocol._client_connected_cb is not None
+ protocol = self._protocol
+ await self.drain()
+ new_transport = await self._loop.start_tls( # type: ignore
+ self._transport, protocol, sslcontext,
+ server_side=server_side, server_hostname=server_hostname,
+ ssl_handshake_timeout=ssl_handshake_timeout)
+ self._transport = new_transport
+ protocol._replace_writer(self)
+
class StreamReader:
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
index 227b227..a7d1789 100644
--- a/Lib/test/test_asyncio/test_streams.py
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -706,6 +706,69 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(messages, [])
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_start_tls(self):
+
+ class MyServer:
+
+ def __init__(self, loop):
+ self.server = None
+ self.loop = loop
+
+ async def handle_client(self, client_reader, client_writer):
+ data1 = await client_reader.readline()
+ client_writer.write(data1)
+ await client_writer.drain()
+ assert client_writer.get_extra_info('sslcontext') is None
+ await client_writer.start_tls(
+ test_utils.simple_server_sslcontext())
+ assert client_writer.get_extra_info('sslcontext') is not None
+ data2 = await client_reader.readline()
+ client_writer.write(data2)
+ await client_writer.drain()
+ client_writer.close()
+ await client_writer.wait_closed()
+
+ def start(self):
+ sock = socket.create_server(('127.0.0.1', 0))
+ self.server = self.loop.run_until_complete(
+ asyncio.start_server(self.handle_client,
+ sock=sock))
+ return sock.getsockname()
+
+ def stop(self):
+ if self.server is not None:
+ self.server.close()
+ self.loop.run_until_complete(self.server.wait_closed())
+ self.server = None
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(*addr)
+ writer.write(b"hello world 1!\n")
+ await writer.drain()
+ msgback1 = await reader.readline()
+ assert writer.get_extra_info('sslcontext') is None
+ await writer.start_tls(test_utils.simple_client_sslcontext())
+ assert writer.get_extra_info('sslcontext') is not None
+ writer.write(b"hello world 2!\n")
+ await writer.drain()
+ msgback2 = await reader.readline()
+ writer.close()
+ await writer.wait_closed()
+ return msgback1, msgback2
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+ server = MyServer(self.loop)
+ addr = server.start()
+ msg1, msg2 = self.loop.run_until_complete(client(addr))
+ server.stop()
+
+ self.assertEqual(messages, [])
+ self.assertEqual(msg1, b"hello world 1!\n")
+ self.assertEqual(msg2, b"hello world 2!\n")
+
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
def test_read_all_from_pipe_reader(self):
# See asyncio issue 168. This test is derived from the example
diff --git a/Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst b/Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst
new file mode 100644
index 0000000..1576269
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst
@@ -0,0 +1,3 @@
+Adds a ``start_tls()`` method to :class:`~asyncio.streams.StreamWriter`,
+which upgrades the connection with TLS using the given
+:class:`~ssl.SSLContext`.