diff options
author | Yury Selivanov <yselivanov@sprymix.com> | 2014-02-18 17:15:06 (GMT) |
---|---|---|
committer | Yury Selivanov <yselivanov@sprymix.com> | 2014-02-18 17:15:06 (GMT) |
commit | 88a5bf0b2e2a55d8418132001a611af9c0419665 (patch) | |
tree | 03841a088e2f8c04c8182999944711f4db039053 /Lib/asyncio/test_utils.py | |
parent | c36e504c53bb20ee6880b78d77aa1378519c3743 (diff) | |
download | cpython-88a5bf0b2e2a55d8418132001a611af9c0419665.zip cpython-88a5bf0b2e2a55d8418132001a611af9c0419665.tar.gz cpython-88a5bf0b2e2a55d8418132001a611af9c0419665.tar.bz2 |
asyncio: Add support for UNIX Domain Sockets.
Diffstat (limited to 'Lib/asyncio/test_utils.py')
-rw-r--r-- | Lib/asyncio/test_utils.py | 153 |
1 files changed, 119 insertions, 34 deletions
diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py index deab7c3..de2916b 100644 --- a/Lib/asyncio/test_utils.py +++ b/Lib/asyncio/test_utils.py @@ -4,12 +4,18 @@ import collections import contextlib import io import os +import socket +import socketserver import sys +import tempfile import threading import time import unittest import unittest.mock + +from http.server import HTTPServer from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer + try: import ssl except ImportError: # pragma: no cover @@ -70,42 +76,51 @@ def run_once(loop): loop.run_forever() -@contextlib.contextmanager -def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): +class SilentWSGIRequestHandler(WSGIRequestHandler): - class SilentWSGIRequestHandler(WSGIRequestHandler): - def get_stderr(self): - return io.StringIO() + def get_stderr(self): + return io.StringIO() - def log_message(self, format, *args): - pass + def log_message(self, format, *args): + pass - class SilentWSGIServer(WSGIServer): - def handle_error(self, request, client_address): + +class SilentWSGIServer(WSGIServer): + + def handle_error(self, request, client_address): + pass + + +class SSLWSGIServerMixin: + + def finish_request(self, request, client_address): + # The relative location of our test directory (which + # contains the ssl key and certificate files) differs + # between the stdlib and stand-alone asyncio. + # Prefer our own if we can find it. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + if not os.path.isdir(here): + here = os.path.join(os.path.dirname(os.__file__), + 'test', 'test_asyncio') + keyfile = os.path.join(here, 'ssl_key.pem') + certfile = os.path.join(here, 'ssl_cert.pem') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer pass - class SSLWSGIServer(SilentWSGIServer): - def finish_request(self, request, client_address): - # The relative location of our test directory (which - # contains the ssl key and certificate files) differs - # between the stdlib and stand-alone asyncio. - # Prefer our own if we can find it. - here = os.path.join(os.path.dirname(__file__), '..', 'tests') - if not os.path.isdir(here): - here = os.path.join(os.path.dirname(os.__file__), - 'test', 'test_asyncio') - keyfile = os.path.join(here, 'ssl_key.pem') - certfile = os.path.join(here, 'ssl_cert.pem') - ssock = ssl.wrap_socket(request, - keyfile=keyfile, - certfile=certfile, - server_side=True) - try: - self.RequestHandlerClass(ssock, client_address, self) - ssock.close() - except OSError: - # maybe socket has been closed by peer - pass + +class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): + pass + + +def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): def app(environ, start_response): status = '200 OK' @@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): # Run the test WSGI server in a separate thread in order not to # interfere with event handling in the main thread - server_class = SSLWSGIServer if use_ssl else SilentWSGIServer - httpd = make_server(host, port, app, - server_class, SilentWSGIRequestHandler) + server_class = server_ssl_cls if use_ssl else server_cls + httpd = server_class(address, SilentWSGIRequestHandler) + httpd.set_app(app) httpd.address = httpd.server_address server_thread = threading.Thread(target=httpd.serve_forever) server_thread.start() @@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): server_thread.join() +if hasattr(socket, 'AF_UNIX'): + + class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): + + def server_bind(self): + socketserver.UnixStreamServer.server_bind(self) + self.server_name = '127.0.0.1' + self.server_port = 80 + + + class UnixWSGIServer(UnixHTTPServer, WSGIServer): + + def server_bind(self): + UnixHTTPServer.server_bind(self) + self.setup_environ() + + def get_request(self): + request, client_addr = super().get_request() + # Code in the stdlib expects that get_request + # will return a socket and a tuple (host, port). + # However, this isn't true for UNIX sockets, + # as the second return value will be a path; + # hence we return some fake data sufficient + # to get the tests going + return request, ('127.0.0.1', '') + + + class SilentUnixWSGIServer(UnixWSGIServer): + + def handle_error(self, request, client_address): + pass + + + class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer): + pass + + + def gen_unix_socket_path(): + with tempfile.NamedTemporaryFile() as file: + return file.name + + + @contextlib.contextmanager + def unix_socket_path(): + path = gen_unix_socket_path() + try: + yield path + finally: + try: + os.unlink(path) + except OSError: + pass + + + @contextlib.contextmanager + def run_test_unix_server(*, use_ssl=False): + with unix_socket_path() as path: + yield from _run_test_server(address=path, use_ssl=use_ssl, + server_cls=SilentUnixWSGIServer, + server_ssl_cls=UnixSSLWSGIServer) + + +@contextlib.contextmanager +def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): + yield from _run_test_server(address=(host, port), use_ssl=use_ssl, + server_cls=SilentWSGIServer, + server_ssl_cls=SSLWSGIServer) + + def make_test_protocol(base): dct = {} for name in dir(base): @@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop): def _write_to_self(self): pass + def MockCallback(**kwargs): return unittest.mock.Mock(spec=['__call__'], **kwargs) |