summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYury Selivanov <yury@magic.io>2018-05-29 09:02:40 (GMT)
committerAndrew Svetlov <andrew.svetlov@gmail.com>2018-05-29 09:02:40 (GMT)
commit2179022d94937d7b0600b0dc192ca6fa5f53d830 (patch)
tree7d548110d8138728e42d5621a0a38424abd47bcd
parentf295587c45f96b62d24f9a12cef6931b0805f596 (diff)
downloadcpython-2179022d94937d7b0600b0dc192ca6fa5f53d830.zip
cpython-2179022d94937d7b0600b0dc192ca6fa5f53d830.tar.gz
cpython-2179022d94937d7b0600b0dc192ca6fa5f53d830.tar.bz2
bpo-33654: Support protocol type switching in SSLTransport.set_protocol() (#7194)
-rw-r--r--Lib/asyncio/sslproto.py11
-rw-r--r--Lib/test/test_asyncio/test_sslproto.py47
-rw-r--r--Misc/NEWS.d/next/Library/2018-05-29-01-13-39.bpo-33654.sa81Si.rst1
3 files changed, 44 insertions, 15 deletions
diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py
index ab43e93..a6d382e 100644
--- a/Lib/asyncio/sslproto.py
+++ b/Lib/asyncio/sslproto.py
@@ -295,7 +295,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
return self._ssl_protocol._get_extra_info(name, default)
def set_protocol(self, protocol):
- self._ssl_protocol._app_protocol = protocol
+ self._ssl_protocol._set_app_protocol(protocol)
def get_protocol(self):
return self._ssl_protocol._app_protocol
@@ -440,9 +440,7 @@ class SSLProtocol(protocols.Protocol):
self._waiter = waiter
self._loop = loop
- self._app_protocol = app_protocol
- self._app_protocol_is_buffer = \
- isinstance(app_protocol, protocols.BufferedProtocol)
+ self._set_app_protocol(app_protocol)
self._app_transport = _SSLProtocolTransport(self._loop, self)
# _SSLPipe instance (None until the connection is made)
self._sslpipe = None
@@ -454,6 +452,11 @@ class SSLProtocol(protocols.Protocol):
self._call_connection_made = call_connection_made
self._ssl_handshake_timeout = ssl_handshake_timeout
+ def _set_app_protocol(self, app_protocol):
+ self._app_protocol = app_protocol
+ self._app_protocol_is_buffer = \
+ isinstance(app_protocol, protocols.BufferedProtocol)
+
def _wakeup_waiter(self, exc=None):
if self._waiter is None:
return
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py
index 1b2f9d2..fa9cbd5 100644
--- a/Lib/test/test_asyncio/test_sslproto.py
+++ b/Lib/test/test_asyncio/test_sslproto.py
@@ -302,6 +302,7 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
+ client_con_made_calls = 0
def serve(sock):
sock.settimeout(self.TIMEOUT)
@@ -315,20 +316,21 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
+ sock.sendall(b'2')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
sock.shutdown(socket.SHUT_RDWR)
sock.close()
- class ClientProto(asyncio.BufferedProtocol):
- def __init__(self, on_data, on_eof):
+ class ClientProtoFirst(asyncio.BufferedProtocol):
+ def __init__(self, on_data):
self.on_data = on_data
- self.on_eof = on_eof
- self.con_made_cnt = 0
self.buf = bytearray(1)
- def connection_made(proto, tr):
- proto.con_made_cnt += 1
- # Ensure connection_made gets called only once.
- self.assertEqual(proto.con_made_cnt, 1)
+ def connection_made(self, tr):
+ nonlocal client_con_made_calls
+ client_con_made_calls += 1
def get_buffer(self, sizehint):
return self.buf
@@ -337,27 +339,50 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
assert nsize == 1
self.on_data.set_result(bytes(self.buf[:nsize]))
+ class ClientProtoSecond(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(self, tr):
+ nonlocal client_con_made_calls
+ client_con_made_calls += 1
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
def eof_received(self):
self.on_eof.set_result(True)
async def client(addr):
await asyncio.sleep(0.5, loop=self.loop)
- on_data = self.loop.create_future()
+ on_data1 = self.loop.create_future()
+ on_data2 = self.loop.create_future()
on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection(
- lambda: ClientProto(on_data, on_eof), *addr)
+ lambda: ClientProtoFirst(on_data1), *addr)
tr.write(HELLO_MSG)
new_tr = await self.loop.start_tls(tr, proto, client_context)
- self.assertEqual(await on_data, b'O')
+ self.assertEqual(await on_data1, b'O')
+ new_tr.write(HELLO_MSG)
+
+ new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
+ self.assertEqual(await on_data2, b'2')
new_tr.write(HELLO_MSG)
await on_eof
new_tr.close()
+ # connection_made() should be called only once -- when
+ # we establish connection for the first time. Start TLS
+ # doesn't call connection_made() on application protocols.
+ self.assertEqual(client_con_made_calls, 1)
+
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr),
diff --git a/Misc/NEWS.d/next/Library/2018-05-29-01-13-39.bpo-33654.sa81Si.rst b/Misc/NEWS.d/next/Library/2018-05-29-01-13-39.bpo-33654.sa81Si.rst
new file mode 100644
index 0000000..39e8e61
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-05-29-01-13-39.bpo-33654.sa81Si.rst
@@ -0,0 +1 @@
+Support protocol type switching in SSLTransport.set_protocol().