diff options
Diffstat (limited to 'Lib/test/test_asyncio/functional.py')
-rw-r--r-- | Lib/test/test_asyncio/functional.py | 279 |
1 files changed, 279 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/functional.py b/Lib/test/test_asyncio/functional.py new file mode 100644 index 0000000..5fd174b --- /dev/null +++ b/Lib/test/test_asyncio/functional.py @@ -0,0 +1,279 @@ +import asyncio +import asyncio.events +import contextlib +import os +import pprint +import select +import socket +import ssl +import tempfile +import threading + + +class FunctionalTestCaseMixin: + + def new_loop(self): + return asyncio.new_event_loop() + + def run_loop_briefly(self, *, delay=0.01): + self.loop.run_until_complete(asyncio.sleep(delay, loop=self.loop)) + + def loop_exception_handler(self, loop, context): + self.__unhandled_exceptions.append(context) + self.loop.default_exception_handler(context) + + def setUp(self): + self.loop = self.new_loop() + asyncio.set_event_loop(None) + + self.loop.set_exception_handler(self.loop_exception_handler) + self.__unhandled_exceptions = [] + + # Disable `_get_running_loop`. + self._old_get_running_loop = asyncio.events._get_running_loop + asyncio.events._get_running_loop = lambda: None + + def tearDown(self): + try: + self.loop.close() + + if self.__unhandled_exceptions: + print('Unexpected calls to loop.call_exception_handler():') + pprint.pprint(self.__unhandled_exceptions) + self.fail('unexpected calls to loop.call_exception_handler()') + + finally: + asyncio.events._get_running_loop = self._old_get_running_loop + asyncio.set_event_loop(None) + self.loop = None + + def tcp_server(self, server_prog, *, + family=socket.AF_INET, + addr=None, + timeout=5, + backlog=1, + max_clients=10): + + if addr is None: + if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX: + with tempfile.NamedTemporaryFile() as tmp: + addr = tmp.name + else: + addr = ('127.0.0.1', 0) + + sock = socket.socket(family, socket.SOCK_STREAM) + + if timeout is None: + raise RuntimeError('timeout is required') + if timeout <= 0: + raise RuntimeError('only blocking sockets are supported') + sock.settimeout(timeout) + + try: + sock.bind(addr) + sock.listen(backlog) + except OSError as ex: + sock.close() + raise ex + + return TestThreadedServer( + self, sock, server_prog, timeout, max_clients) + + def tcp_client(self, client_prog, + family=socket.AF_INET, + timeout=10): + + sock = socket.socket(family, socket.SOCK_STREAM) + + if timeout is None: + raise RuntimeError('timeout is required') + if timeout <= 0: + raise RuntimeError('only blocking sockets are supported') + sock.settimeout(timeout) + + return TestThreadedClient( + self, sock, client_prog, timeout) + + def unix_server(self, *args, **kwargs): + if not hasattr(socket, 'AF_UNIX'): + raise NotImplementedError + return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs) + + def unix_client(self, *args, **kwargs): + if not hasattr(socket, 'AF_UNIX'): + raise NotImplementedError + return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs) + + @contextlib.contextmanager + def unix_sock_name(self): + with tempfile.TemporaryDirectory() as td: + fn = os.path.join(td, 'sock') + try: + yield fn + finally: + try: + os.unlink(fn) + except OSError: + pass + + def _abort_socket_test(self, ex): + try: + self.loop.stop() + finally: + self.fail(ex) + + +############################################################################## +# Socket Testing Utilities +############################################################################## + + +class TestSocketWrapper: + + def __init__(self, sock): + self.__sock = sock + + def recv_all(self, n): + buf = b'' + while len(buf) < n: + data = self.recv(n - len(buf)) + if data == b'': + raise ConnectionAbortedError + buf += data + return buf + + def start_tls(self, ssl_context, *, + server_side=False, + server_hostname=None): + + assert isinstance(ssl_context, ssl.SSLContext) + + ssl_sock = ssl_context.wrap_socket( + self.__sock, server_side=server_side, + server_hostname=server_hostname, + do_handshake_on_connect=False) + + ssl_sock.do_handshake() + + self.__sock.close() + self.__sock = ssl_sock + + def __getattr__(self, name): + return getattr(self.__sock, name) + + def __repr__(self): + return '<{} {!r}>'.format(type(self).__name__, self.__sock) + + +class SocketThread(threading.Thread): + + def stop(self): + self._active = False + self.join() + + def __enter__(self): + self.start() + return self + + def __exit__(self, *exc): + self.stop() + + +class TestThreadedClient(SocketThread): + + def __init__(self, test, sock, prog, timeout): + threading.Thread.__init__(self, None, None, 'test-client') + self.daemon = True + + self._timeout = timeout + self._sock = sock + self._active = True + self._prog = prog + self._test = test + + def run(self): + try: + self._prog(TestSocketWrapper(self._sock)) + except Exception as ex: + self._test._abort_socket_test(ex) + + +class TestThreadedServer(SocketThread): + + def __init__(self, test, sock, prog, timeout, max_clients): + threading.Thread.__init__(self, None, None, 'test-server') + self.daemon = True + + self._clients = 0 + self._finished_clients = 0 + self._max_clients = max_clients + self._timeout = timeout + self._sock = sock + self._active = True + + self._prog = prog + + self._s1, self._s2 = socket.socketpair() + self._s1.setblocking(False) + + self._test = test + + def stop(self): + try: + if self._s2 and self._s2.fileno() != -1: + try: + self._s2.send(b'stop') + except OSError: + pass + finally: + super().stop() + + def run(self): + try: + with self._sock: + self._sock.setblocking(0) + self._run() + finally: + self._s1.close() + self._s2.close() + + def _run(self): + while self._active: + if self._clients >= self._max_clients: + return + + r, w, x = select.select( + [self._sock, self._s1], [], [], self._timeout) + + if self._s1 in r: + return + + if self._sock in r: + try: + conn, addr = self._sock.accept() + except BlockingIOError: + continue + except socket.timeout: + if not self._active: + return + else: + raise + else: + self._clients += 1 + conn.settimeout(self._timeout) + try: + with conn: + self._handle_client(conn) + except Exception as ex: + self._active = False + try: + raise + finally: + self._test._abort_socket_test(ex) + + def _handle_client(self, sock): + self._prog(TestSocketWrapper(sock)) + + @property + def addr(self): + return self._sock.getsockname() |