summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_asyncio/test_sslproto.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_asyncio/test_sslproto.py')
-rw-r--r--Lib/test/test_asyncio/test_sslproto.py177
1 files changed, 152 insertions, 25 deletions
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py
index c534a34..932487a 100644
--- a/Lib/test/test_asyncio/test_sslproto.py
+++ b/Lib/test/test_asyncio/test_sslproto.py
@@ -1,8 +1,7 @@
"""Tests for asyncio/sslproto.py."""
-import os
import logging
-import time
+import socket
import unittest
from unittest import mock
try:
@@ -185,17 +184,67 @@ class SslProtoHandshakeTests(test_utils.TestCase):
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
+ PAYLOAD_SIZE = 1024 * 100
+ TIMEOUT = 60
+
def new_loop(self):
raise NotImplementedError
- def test_start_tls_client_1(self):
- HELLO_MSG = b'1' * 1024 * 1024
+ def test_buf_feed_data(self):
+
+ class Proto(asyncio.BufferedProtocol):
+
+ def __init__(self, bufsize, usemv):
+ self.buf = bytearray(bufsize)
+ self.mv = memoryview(self.buf)
+ self.data = b''
+ self.usemv = usemv
+
+ def get_buffer(self, sizehint):
+ if self.usemv:
+ return self.mv
+ else:
+ return self.buf
+
+ def buffer_updated(self, nsize):
+ if self.usemv:
+ self.data += self.mv[:nsize]
+ else:
+ self.data += self.buf[:nsize]
+
+ for usemv in [False, True]:
+ proto = Proto(1, usemv)
+ sslproto._feed_data_to_bufferred_proto(proto, b'12345')
+ self.assertEqual(proto.data, b'12345')
+
+ proto = Proto(2, usemv)
+ sslproto._feed_data_to_bufferred_proto(proto, b'12345')
+ self.assertEqual(proto.data, b'12345')
+
+ proto = Proto(2, usemv)
+ sslproto._feed_data_to_bufferred_proto(proto, b'1234')
+ self.assertEqual(proto.data, b'1234')
+
+ proto = Proto(4, usemv)
+ sslproto._feed_data_to_bufferred_proto(proto, b'1234')
+ self.assertEqual(proto.data, b'1234')
+
+ proto = Proto(100, usemv)
+ sslproto._feed_data_to_bufferred_proto(proto, b'12345')
+ self.assertEqual(proto.data, b'12345')
+
+ proto = Proto(0, usemv)
+ with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
+ sslproto._feed_data_to_bufferred_proto(proto, b'12345')
+
+ def test_start_tls_client_reg_proto_1(self):
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
def serve(sock):
- sock.settimeout(5)
+ sock.settimeout(self.TIMEOUT)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
@@ -205,6 +254,8 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
sock.sendall(b'O')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.shutdown(socket.SHUT_RDWR)
sock.close()
class ClientProto(asyncio.Protocol):
@@ -246,17 +297,80 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
+ def test_start_tls_client_buf_proto_1(self):
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+ server_context = test_utils.simple_server_sslcontext()
+ client_context = test_utils.simple_client_sslcontext()
+
+ def serve(sock):
+ sock.settimeout(self.TIMEOUT)
+
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.start_tls(server_context, server_side=True)
+
+ sock.sendall(b'O')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.shutdown(socket.SHUT_RDWR)
+ sock.close()
+
+ class ClientProto(asyncio.BufferedProtocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+ self.buf = bytearray(1)
+
+ def connection_made(proto, tr):
+ proto.con_made_cnt += 1
+ # Ensure connection_made gets called only once.
+ self.assertEqual(proto.con_made_cnt, 1)
+
+ def get_buffer(self, sizehint):
+ return self.buf
+
+ def buffer_updated(self, nsize):
+ assert nsize == 1
+ self.on_data.set_result(bytes(self.buf[:nsize]))
+
+ def eof_received(self):
+ self.on_eof.set_result(True)
+
+ async def client(addr):
+ await asyncio.sleep(0.5, loop=self.loop)
+
+ on_data = self.loop.create_future()
+ on_eof = self.loop.create_future()
+
+ tr, proto = await self.loop.create_connection(
+ lambda: ClientProto(on_data, on_eof), *addr)
+
+ tr.write(HELLO_MSG)
+ new_tr = await self.loop.start_tls(tr, proto, client_context)
+
+ self.assertEqual(await on_data, b'O')
+ new_tr.write(HELLO_MSG)
+ await on_eof
+
+ new_tr.close()
+
+ with self.tcp_server(serve) as srv:
+ self.loop.run_until_complete(
+ asyncio.wait_for(client(srv.addr),
+ loop=self.loop, timeout=self.TIMEOUT))
+
def test_start_tls_server_1(self):
- HELLO_MSG = b'1' * 1024 * 1024
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
- # TODO: fix TLSv1.3 support
- client_context.options |= ssl.OP_NO_TLSv1_3
def client(sock, addr):
- time.sleep(0.5)
- sock.settimeout(5)
+ sock.settimeout(self.TIMEOUT)
sock.connect(addr)
data = sock.recv_all(len(HELLO_MSG))
@@ -264,12 +378,15 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
sock.start_tls(client_context)
sock.sendall(HELLO_MSG)
+
+ sock.shutdown(socket.SHUT_RDWR)
sock.close()
class ServerProto(asyncio.Protocol):
- def __init__(self, on_con, on_eof):
+ def __init__(self, on_con, on_eof, on_con_lost):
self.on_con = on_con
self.on_eof = on_eof
+ self.on_con_lost = on_con_lost
self.data = b''
def connection_made(self, tr):
@@ -281,7 +398,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
def eof_received(self):
self.on_eof.set_result(1)
- async def main():
+ def connection_lost(self, exc):
+ if exc is None:
+ self.on_con_lost.set_result(None)
+ else:
+ self.on_con_lost.set_exception(exc)
+
+ async def main(proto, on_con, on_eof, on_con_lost):
tr = await on_con
tr.write(HELLO_MSG)
@@ -292,24 +415,29 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
server_side=True)
await on_eof
+ await on_con_lost
self.assertEqual(proto.data, HELLO_MSG)
new_tr.close()
- server.close()
- await server.wait_closed()
+ async def run_main():
+ on_con = self.loop.create_future()
+ on_eof = self.loop.create_future()
+ on_con_lost = self.loop.create_future()
+ proto = ServerProto(on_con, on_eof, on_con_lost)
- on_con = self.loop.create_future()
- on_eof = self.loop.create_future()
- proto = ServerProto(on_con, on_eof)
+ server = await self.loop.create_server(
+ lambda: proto, '127.0.0.1', 0)
+ addr = server.sockets[0].getsockname()
- server = self.loop.run_until_complete(
- self.loop.create_server(
- lambda: proto, '127.0.0.1', 0))
- addr = server.sockets[0].getsockname()
+ with self.tcp_client(lambda sock: client(sock, addr)):
+ await asyncio.wait_for(
+ main(proto, on_con, on_eof, on_con_lost),
+ loop=self.loop, timeout=self.TIMEOUT)
- with self.tcp_client(lambda sock: client(sock, addr)):
- self.loop.run_until_complete(
- asyncio.wait_for(main(), loop=self.loop, timeout=10))
+ server.close()
+ await server.wait_closed()
+
+ self.loop.run_until_complete(run_main())
def test_start_tls_wrong_args(self):
async def main():
@@ -332,7 +460,6 @@ class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
-@unittest.skipIf(os.environ.get('APPVEYOR'), 'XXX: issue 32458')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
def new_loop(self):