summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_asyncio/test_sslproto.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_asyncio/test_sslproto.py')
-rw-r--r--Lib/test/test_asyncio/test_sslproto.py47
1 files changed, 36 insertions, 11 deletions
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py
index 1b2f9d2..fa9cbd5 100644
--- a/Lib/test/test_asyncio/test_sslproto.py
+++ b/Lib/test/test_asyncio/test_sslproto.py
@@ -302,6 +302,7 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
+ client_con_made_calls = 0
def serve(sock):
sock.settimeout(self.TIMEOUT)
@@ -315,20 +316,21 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
+ sock.sendall(b'2')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
sock.shutdown(socket.SHUT_RDWR)
sock.close()
- class ClientProto(asyncio.BufferedProtocol):
- def __init__(self, on_data, on_eof):
+ class ClientProtoFirst(asyncio.BufferedProtocol):
+ def __init__(self, on_data):
self.on_data = on_data
- self.on_eof = on_eof
- self.con_made_cnt = 0
self.buf = bytearray(1)
- def connection_made(proto, tr):
- proto.con_made_cnt += 1
- # Ensure connection_made gets called only once.
- self.assertEqual(proto.con_made_cnt, 1)
+ def connection_made(self, tr):
+ nonlocal client_con_made_calls
+ client_con_made_calls += 1
def get_buffer(self, sizehint):
return self.buf
@@ -337,27 +339,50 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
assert nsize == 1
self.on_data.set_result(bytes(self.buf[:nsize]))
+ class ClientProtoSecond(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(self, tr):
+ nonlocal client_con_made_calls
+ client_con_made_calls += 1
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
def eof_received(self):
self.on_eof.set_result(True)
async def client(addr):
await asyncio.sleep(0.5, loop=self.loop)
- on_data = self.loop.create_future()
+ on_data1 = self.loop.create_future()
+ on_data2 = self.loop.create_future()
on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection(
- lambda: ClientProto(on_data, on_eof), *addr)
+ lambda: ClientProtoFirst(on_data1), *addr)
tr.write(HELLO_MSG)
new_tr = await self.loop.start_tls(tr, proto, client_context)
- self.assertEqual(await on_data, b'O')
+ self.assertEqual(await on_data1, b'O')
+ new_tr.write(HELLO_MSG)
+
+ new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
+ self.assertEqual(await on_data2, b'2')
new_tr.write(HELLO_MSG)
await on_eof
new_tr.close()
+ # connection_made() should be called only once -- when
+ # we establish connection for the first time. Start TLS
+ # doesn't call connection_made() on application protocols.
+ self.assertEqual(client_con_made_calls, 1)
+
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr),