summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_asyncio/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_asyncio/functional.py')
-rw-r--r--Lib/test/test_asyncio/functional.py279
1 files changed, 279 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/functional.py b/Lib/test/test_asyncio/functional.py
new file mode 100644
index 0000000..5fd174b
--- /dev/null
+++ b/Lib/test/test_asyncio/functional.py
@@ -0,0 +1,279 @@
+import asyncio
+import asyncio.events
+import contextlib
+import os
+import pprint
+import select
+import socket
+import ssl
+import tempfile
+import threading
+
+
+class FunctionalTestCaseMixin:
+
+ def new_loop(self):
+ return asyncio.new_event_loop()
+
+ def run_loop_briefly(self, *, delay=0.01):
+ self.loop.run_until_complete(asyncio.sleep(delay, loop=self.loop))
+
+ def loop_exception_handler(self, loop, context):
+ self.__unhandled_exceptions.append(context)
+ self.loop.default_exception_handler(context)
+
+ def setUp(self):
+ self.loop = self.new_loop()
+ asyncio.set_event_loop(None)
+
+ self.loop.set_exception_handler(self.loop_exception_handler)
+ self.__unhandled_exceptions = []
+
+ # Disable `_get_running_loop`.
+ self._old_get_running_loop = asyncio.events._get_running_loop
+ asyncio.events._get_running_loop = lambda: None
+
+ def tearDown(self):
+ try:
+ self.loop.close()
+
+ if self.__unhandled_exceptions:
+ print('Unexpected calls to loop.call_exception_handler():')
+ pprint.pprint(self.__unhandled_exceptions)
+ self.fail('unexpected calls to loop.call_exception_handler()')
+
+ finally:
+ asyncio.events._get_running_loop = self._old_get_running_loop
+ asyncio.set_event_loop(None)
+ self.loop = None
+
+ def tcp_server(self, server_prog, *,
+ family=socket.AF_INET,
+ addr=None,
+ timeout=5,
+ backlog=1,
+ max_clients=10):
+
+ if addr is None:
+ if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
+ with tempfile.NamedTemporaryFile() as tmp:
+ addr = tmp.name
+ else:
+ addr = ('127.0.0.1', 0)
+
+ sock = socket.socket(family, socket.SOCK_STREAM)
+
+ if timeout is None:
+ raise RuntimeError('timeout is required')
+ if timeout <= 0:
+ raise RuntimeError('only blocking sockets are supported')
+ sock.settimeout(timeout)
+
+ try:
+ sock.bind(addr)
+ sock.listen(backlog)
+ except OSError as ex:
+ sock.close()
+ raise ex
+
+ return TestThreadedServer(
+ self, sock, server_prog, timeout, max_clients)
+
+ def tcp_client(self, client_prog,
+ family=socket.AF_INET,
+ timeout=10):
+
+ sock = socket.socket(family, socket.SOCK_STREAM)
+
+ if timeout is None:
+ raise RuntimeError('timeout is required')
+ if timeout <= 0:
+ raise RuntimeError('only blocking sockets are supported')
+ sock.settimeout(timeout)
+
+ return TestThreadedClient(
+ self, sock, client_prog, timeout)
+
+ def unix_server(self, *args, **kwargs):
+ if not hasattr(socket, 'AF_UNIX'):
+ raise NotImplementedError
+ return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
+
+ def unix_client(self, *args, **kwargs):
+ if not hasattr(socket, 'AF_UNIX'):
+ raise NotImplementedError
+ return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
+
+ @contextlib.contextmanager
+ def unix_sock_name(self):
+ with tempfile.TemporaryDirectory() as td:
+ fn = os.path.join(td, 'sock')
+ try:
+ yield fn
+ finally:
+ try:
+ os.unlink(fn)
+ except OSError:
+ pass
+
+ def _abort_socket_test(self, ex):
+ try:
+ self.loop.stop()
+ finally:
+ self.fail(ex)
+
+
+##############################################################################
+# Socket Testing Utilities
+##############################################################################
+
+
+class TestSocketWrapper:
+
+ def __init__(self, sock):
+ self.__sock = sock
+
+ def recv_all(self, n):
+ buf = b''
+ while len(buf) < n:
+ data = self.recv(n - len(buf))
+ if data == b'':
+ raise ConnectionAbortedError
+ buf += data
+ return buf
+
+ def start_tls(self, ssl_context, *,
+ server_side=False,
+ server_hostname=None):
+
+ assert isinstance(ssl_context, ssl.SSLContext)
+
+ ssl_sock = ssl_context.wrap_socket(
+ self.__sock, server_side=server_side,
+ server_hostname=server_hostname,
+ do_handshake_on_connect=False)
+
+ ssl_sock.do_handshake()
+
+ self.__sock.close()
+ self.__sock = ssl_sock
+
+ def __getattr__(self, name):
+ return getattr(self.__sock, name)
+
+ def __repr__(self):
+ return '<{} {!r}>'.format(type(self).__name__, self.__sock)
+
+
+class SocketThread(threading.Thread):
+
+ def stop(self):
+ self._active = False
+ self.join()
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, *exc):
+ self.stop()
+
+
+class TestThreadedClient(SocketThread):
+
+ def __init__(self, test, sock, prog, timeout):
+ threading.Thread.__init__(self, None, None, 'test-client')
+ self.daemon = True
+
+ self._timeout = timeout
+ self._sock = sock
+ self._active = True
+ self._prog = prog
+ self._test = test
+
+ def run(self):
+ try:
+ self._prog(TestSocketWrapper(self._sock))
+ except Exception as ex:
+ self._test._abort_socket_test(ex)
+
+
+class TestThreadedServer(SocketThread):
+
+ def __init__(self, test, sock, prog, timeout, max_clients):
+ threading.Thread.__init__(self, None, None, 'test-server')
+ self.daemon = True
+
+ self._clients = 0
+ self._finished_clients = 0
+ self._max_clients = max_clients
+ self._timeout = timeout
+ self._sock = sock
+ self._active = True
+
+ self._prog = prog
+
+ self._s1, self._s2 = socket.socketpair()
+ self._s1.setblocking(False)
+
+ self._test = test
+
+ def stop(self):
+ try:
+ if self._s2 and self._s2.fileno() != -1:
+ try:
+ self._s2.send(b'stop')
+ except OSError:
+ pass
+ finally:
+ super().stop()
+
+ def run(self):
+ try:
+ with self._sock:
+ self._sock.setblocking(0)
+ self._run()
+ finally:
+ self._s1.close()
+ self._s2.close()
+
+ def _run(self):
+ while self._active:
+ if self._clients >= self._max_clients:
+ return
+
+ r, w, x = select.select(
+ [self._sock, self._s1], [], [], self._timeout)
+
+ if self._s1 in r:
+ return
+
+ if self._sock in r:
+ try:
+ conn, addr = self._sock.accept()
+ except BlockingIOError:
+ continue
+ except socket.timeout:
+ if not self._active:
+ return
+ else:
+ raise
+ else:
+ self._clients += 1
+ conn.settimeout(self._timeout)
+ try:
+ with conn:
+ self._handle_client(conn)
+ except Exception as ex:
+ self._active = False
+ try:
+ raise
+ finally:
+ self._test._abort_socket_test(ex)
+
+ def _handle_client(self, sock):
+ self._prog(TestSocketWrapper(sock))
+
+ @property
+ def addr(self):
+ return self._sock.getsockname()