summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_asyncio/utils.py
diff options
context:
space:
mode:
authorYury Selivanov <yury@magic.io>2017-12-11 15:04:40 (GMT)
committerGitHub <noreply@github.com>2017-12-11 15:04:40 (GMT)
commit3e9751819ad13a965e8be13c1e5bc5a6811fe6b8 (patch)
treed93da6e327fb548daf2a370b24b1b3479e7f3ed5 /Lib/test/test_asyncio/utils.py
parent9edad3c7011ccab0a66a065933abebf3288cf1a1 (diff)
downloadcpython-3e9751819ad13a965e8be13c1e5bc5a6811fe6b8.zip
cpython-3e9751819ad13a965e8be13c1e5bc5a6811fe6b8.tar.gz
cpython-3e9751819ad13a965e8be13c1e5bc5a6811fe6b8.tar.bz2
bpo-32273: Move asyncio.test_utils to test.test_asyncio (#4785)
Diffstat (limited to 'Lib/test/test_asyncio/utils.py')
-rw-r--r--Lib/test/test_asyncio/utils.py502
1 files changed, 502 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py
new file mode 100644
index 0000000..560db9f
--- /dev/null
+++ b/Lib/test/test_asyncio/utils.py
@@ -0,0 +1,502 @@
+"""Utilities shared by tests."""
+
+import collections
+import contextlib
+import io
+import logging
+import os
+import re
+import selectors
+import socket
+import socketserver
+import sys
+import tempfile
+import threading
+import time
+import unittest
+import weakref
+
+from unittest import mock
+
+from http.server import HTTPServer
+from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
+
+try:
+ import ssl
+except ImportError: # pragma: no cover
+ ssl = None
+
+from asyncio import base_events
+from asyncio import events
+from asyncio import futures
+from asyncio import tasks
+from asyncio.log import logger
+from test import support
+
+
+def dummy_ssl_context():
+ if ssl is None:
+ return None
+ else:
+ return ssl.SSLContext(ssl.PROTOCOL_TLS)
+
+
+def run_briefly(loop):
+ async def once():
+ pass
+ gen = once()
+ t = loop.create_task(gen)
+ # Don't log a warning if the task is not done after run_until_complete().
+ # It occurs if the loop is stopped or if a task raises a BaseException.
+ t._log_destroy_pending = False
+ try:
+ loop.run_until_complete(t)
+ finally:
+ gen.close()
+
+
+def run_until(loop, pred, timeout=30):
+ deadline = time.time() + timeout
+ while not pred():
+ if timeout is not None:
+ timeout = deadline - time.time()
+ if timeout <= 0:
+ raise futures.TimeoutError()
+ loop.run_until_complete(tasks.sleep(0.001, loop=loop))
+
+
+def run_once(loop):
+ """Legacy API to run once through the event loop.
+
+ This is the recommended pattern for test code. It will poll the
+ selector once and run all callbacks scheduled in response to I/O
+ events.
+ """
+ loop.call_soon(loop.stop)
+ loop.run_forever()
+
+
+class SilentWSGIRequestHandler(WSGIRequestHandler):
+
+ def get_stderr(self):
+ return io.StringIO()
+
+ def log_message(self, format, *args):
+ pass
+
+
+class SilentWSGIServer(WSGIServer):
+
+ request_timeout = 2
+
+ def get_request(self):
+ request, client_addr = super().get_request()
+ request.settimeout(self.request_timeout)
+ return request, client_addr
+
+ def handle_error(self, request, client_address):
+ pass
+
+
+class SSLWSGIServerMixin:
+
+ def finish_request(self, request, client_address):
+ # The relative location of our test directory (which
+ # contains the ssl key and certificate files) differs
+ # between the stdlib and stand-alone asyncio.
+ # Prefer our own if we can find it.
+ here = os.path.join(os.path.dirname(__file__), '..', 'tests')
+ if not os.path.isdir(here):
+ here = os.path.join(os.path.dirname(os.__file__),
+ 'test', 'test_asyncio')
+ keyfile = os.path.join(here, 'ssl_key.pem')
+ certfile = os.path.join(here, 'ssl_cert.pem')
+ context = ssl.SSLContext()
+ context.load_cert_chain(certfile, keyfile)
+
+ ssock = context.wrap_socket(request, server_side=True)
+ try:
+ self.RequestHandlerClass(ssock, client_address, self)
+ ssock.close()
+ except OSError:
+ # maybe socket has been closed by peer
+ pass
+
+
+class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
+ pass
+
+
+def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
+
+ def app(environ, start_response):
+ status = '200 OK'
+ headers = [('Content-type', 'text/plain')]
+ start_response(status, headers)
+ return [b'Test message']
+
+ # Run the test WSGI server in a separate thread in order not to
+ # interfere with event handling in the main thread
+ server_class = server_ssl_cls if use_ssl else server_cls
+ httpd = server_class(address, SilentWSGIRequestHandler)
+ httpd.set_app(app)
+ httpd.address = httpd.server_address
+ server_thread = threading.Thread(
+ target=lambda: httpd.serve_forever(poll_interval=0.05))
+ server_thread.start()
+ try:
+ yield httpd
+ finally:
+ httpd.shutdown()
+ httpd.server_close()
+ server_thread.join()
+
+
+if hasattr(socket, 'AF_UNIX'):
+
+ class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
+
+ def server_bind(self):
+ socketserver.UnixStreamServer.server_bind(self)
+ self.server_name = '127.0.0.1'
+ self.server_port = 80
+
+
+ class UnixWSGIServer(UnixHTTPServer, WSGIServer):
+
+ request_timeout = 2
+
+ def server_bind(self):
+ UnixHTTPServer.server_bind(self)
+ self.setup_environ()
+
+ def get_request(self):
+ request, client_addr = super().get_request()
+ request.settimeout(self.request_timeout)
+ # Code in the stdlib expects that get_request
+ # will return a socket and a tuple (host, port).
+ # However, this isn't true for UNIX sockets,
+ # as the second return value will be a path;
+ # hence we return some fake data sufficient
+ # to get the tests going
+ return request, ('127.0.0.1', '')
+
+
+ class SilentUnixWSGIServer(UnixWSGIServer):
+
+ def handle_error(self, request, client_address):
+ pass
+
+
+ class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
+ pass
+
+
+ def gen_unix_socket_path():
+ with tempfile.NamedTemporaryFile() as file:
+ return file.name
+
+
+ @contextlib.contextmanager
+ def unix_socket_path():
+ path = gen_unix_socket_path()
+ try:
+ yield path
+ finally:
+ try:
+ os.unlink(path)
+ except OSError:
+ pass
+
+
+ @contextlib.contextmanager
+ def run_test_unix_server(*, use_ssl=False):
+ with unix_socket_path() as path:
+ yield from _run_test_server(address=path, use_ssl=use_ssl,
+ server_cls=SilentUnixWSGIServer,
+ server_ssl_cls=UnixSSLWSGIServer)
+
+
+@contextlib.contextmanager
+def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
+ yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
+ server_cls=SilentWSGIServer,
+ server_ssl_cls=SSLWSGIServer)
+
+
+def make_test_protocol(base):
+ dct = {}
+ for name in dir(base):
+ if name.startswith('__') and name.endswith('__'):
+ # skip magic names
+ continue
+ dct[name] = MockCallback(return_value=None)
+ return type('TestProtocol', (base,) + base.__bases__, dct)()
+
+
+class TestSelector(selectors.BaseSelector):
+
+ def __init__(self):
+ self.keys = {}
+
+ def register(self, fileobj, events, data=None):
+ key = selectors.SelectorKey(fileobj, 0, events, data)
+ self.keys[fileobj] = key
+ return key
+
+ def unregister(self, fileobj):
+ return self.keys.pop(fileobj)
+
+ def select(self, timeout):
+ return []
+
+ def get_map(self):
+ return self.keys
+
+
+class TestLoop(base_events.BaseEventLoop):
+ """Loop for unittests.
+
+ It manages self time directly.
+ If something scheduled to be executed later then
+ on next loop iteration after all ready handlers done
+ generator passed to __init__ is calling.
+
+ Generator should be like this:
+
+ def gen():
+ ...
+ when = yield ...
+ ... = yield time_advance
+
+ Value returned by yield is absolute time of next scheduled handler.
+ Value passed to yield is time advance to move loop's time forward.
+ """
+
+ def __init__(self, gen=None):
+ super().__init__()
+
+ if gen is None:
+ def gen():
+ yield
+ self._check_on_close = False
+ else:
+ self._check_on_close = True
+
+ self._gen = gen()
+ next(self._gen)
+ self._time = 0
+ self._clock_resolution = 1e-9
+ self._timers = []
+ self._selector = TestSelector()
+
+ self.readers = {}
+ self.writers = {}
+ self.reset_counters()
+
+ self._transports = weakref.WeakValueDictionary()
+
+ def time(self):
+ return self._time
+
+ def advance_time(self, advance):
+ """Move test time forward."""
+ if advance:
+ self._time += advance
+
+ def close(self):
+ super().close()
+ if self._check_on_close:
+ try:
+ self._gen.send(0)
+ except StopIteration:
+ pass
+ else: # pragma: no cover
+ raise AssertionError("Time generator is not finished")
+
+ def _add_reader(self, fd, callback, *args):
+ self.readers[fd] = events.Handle(callback, args, self)
+
+ def _remove_reader(self, fd):
+ self.remove_reader_count[fd] += 1
+ if fd in self.readers:
+ del self.readers[fd]
+ return True
+ else:
+ return False
+
+ def assert_reader(self, fd, callback, *args):
+ assert fd in self.readers, 'fd {} is not registered'.format(fd)
+ handle = self.readers[fd]
+ assert handle._callback == callback, '{!r} != {!r}'.format(
+ handle._callback, callback)
+ assert handle._args == args, '{!r} != {!r}'.format(
+ handle._args, args)
+
+ def _add_writer(self, fd, callback, *args):
+ self.writers[fd] = events.Handle(callback, args, self)
+
+ def _remove_writer(self, fd):
+ self.remove_writer_count[fd] += 1
+ if fd in self.writers:
+ del self.writers[fd]
+ return True
+ else:
+ return False
+
+ def assert_writer(self, fd, callback, *args):
+ assert fd in self.writers, 'fd {} is not registered'.format(fd)
+ handle = self.writers[fd]
+ assert handle._callback == callback, '{!r} != {!r}'.format(
+ handle._callback, callback)
+ assert handle._args == args, '{!r} != {!r}'.format(
+ handle._args, args)
+
+ def _ensure_fd_no_transport(self, fd):
+ if not isinstance(fd, int):
+ try:
+ fd = int(fd.fileno())
+ except (AttributeError, TypeError, ValueError):
+ # This code matches selectors._fileobj_to_fd function.
+ raise ValueError("Invalid file object: "
+ "{!r}".format(fd)) from None
+ try:
+ transport = self._transports[fd]
+ except KeyError:
+ pass
+ else:
+ raise RuntimeError(
+ 'File descriptor {!r} is used by transport {!r}'.format(
+ fd, transport))
+
+ def add_reader(self, fd, callback, *args):
+ """Add a reader callback."""
+ self._ensure_fd_no_transport(fd)
+ return self._add_reader(fd, callback, *args)
+
+ def remove_reader(self, fd):
+ """Remove a reader callback."""
+ self._ensure_fd_no_transport(fd)
+ return self._remove_reader(fd)
+
+ def add_writer(self, fd, callback, *args):
+ """Add a writer callback.."""
+ self._ensure_fd_no_transport(fd)
+ return self._add_writer(fd, callback, *args)
+
+ def remove_writer(self, fd):
+ """Remove a writer callback."""
+ self._ensure_fd_no_transport(fd)
+ return self._remove_writer(fd)
+
+ def reset_counters(self):
+ self.remove_reader_count = collections.defaultdict(int)
+ self.remove_writer_count = collections.defaultdict(int)
+
+ def _run_once(self):
+ super()._run_once()
+ for when in self._timers:
+ advance = self._gen.send(when)
+ self.advance_time(advance)
+ self._timers = []
+
+ def call_at(self, when, callback, *args):
+ self._timers.append(when)
+ return super().call_at(when, callback, *args)
+
+ def _process_events(self, event_list):
+ return
+
+ def _write_to_self(self):
+ pass
+
+
+def MockCallback(**kwargs):
+ return mock.Mock(spec=['__call__'], **kwargs)
+
+
+class MockPattern(str):
+ """A regex based str with a fuzzy __eq__.
+
+ Use this helper with 'mock.assert_called_with', or anywhere
+ where a regex comparison between strings is needed.
+
+ For instance:
+ mock_call.assert_called_with(MockPattern('spam.*ham'))
+ """
+ def __eq__(self, other):
+ return bool(re.search(str(self), other, re.S))
+
+
+def get_function_source(func):
+ source = events._get_function_source(func)
+ if source is None:
+ raise ValueError("unable to get the source of %r" % (func,))
+ return source
+
+
+class TestCase(unittest.TestCase):
+ @staticmethod
+ def close_loop(loop):
+ executor = loop._default_executor
+ if executor is not None:
+ executor.shutdown(wait=True)
+ loop.close()
+
+ def set_event_loop(self, loop, *, cleanup=True):
+ assert loop is not None
+ # ensure that the event loop is passed explicitly in asyncio
+ events.set_event_loop(None)
+ if cleanup:
+ self.addCleanup(self.close_loop, loop)
+
+ def new_test_loop(self, gen=None):
+ loop = TestLoop(gen)
+ self.set_event_loop(loop)
+ return loop
+
+ def unpatch_get_running_loop(self):
+ events._get_running_loop = self._get_running_loop
+
+ def setUp(self):
+ self._get_running_loop = events._get_running_loop
+ events._get_running_loop = lambda: None
+ self._thread_cleanup = support.threading_setup()
+
+ def tearDown(self):
+ self.unpatch_get_running_loop()
+
+ events.set_event_loop(None)
+
+ # Detect CPython bug #23353: ensure that yield/yield-from is not used
+ # in an except block of a generator
+ self.assertEqual(sys.exc_info(), (None, None, None))
+
+ self.doCleanups()
+ support.threading_cleanup(*self._thread_cleanup)
+ support.reap_children()
+
+
+@contextlib.contextmanager
+def disable_logger():
+ """Context manager to disable asyncio logger.
+
+ For example, it can be used to ignore warnings in debug mode.
+ """
+ old_level = logger.level
+ try:
+ logger.setLevel(logging.CRITICAL+1)
+ yield
+ finally:
+ logger.setLevel(old_level)
+
+
+def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
+ family=socket.AF_INET):
+ """Create a mock of a non-blocking socket."""
+ sock = mock.MagicMock(socket.socket)
+ sock.proto = proto
+ sock.type = type
+ sock.family = family
+ sock.gettimeout.return_value = 0.0
+ return sock