summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio/test_utils.py
diff options
context:
space:
mode:
authorYury Selivanov <yselivanov@sprymix.com>2014-02-18 17:15:06 (GMT)
committerYury Selivanov <yselivanov@sprymix.com>2014-02-18 17:15:06 (GMT)
commit88a5bf0b2e2a55d8418132001a611af9c0419665 (patch)
tree03841a088e2f8c04c8182999944711f4db039053 /Lib/asyncio/test_utils.py
parentc36e504c53bb20ee6880b78d77aa1378519c3743 (diff)
downloadcpython-88a5bf0b2e2a55d8418132001a611af9c0419665.zip
cpython-88a5bf0b2e2a55d8418132001a611af9c0419665.tar.gz
cpython-88a5bf0b2e2a55d8418132001a611af9c0419665.tar.bz2
asyncio: Add support for UNIX Domain Sockets.
Diffstat (limited to 'Lib/asyncio/test_utils.py')
-rw-r--r--Lib/asyncio/test_utils.py153
1 files changed, 119 insertions, 34 deletions
diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py
index deab7c3..de2916b 100644
--- a/Lib/asyncio/test_utils.py
+++ b/Lib/asyncio/test_utils.py
@@ -4,12 +4,18 @@ import collections
import contextlib
import io
import os
+import socket
+import socketserver
import sys
+import tempfile
import threading
import time
import unittest
import unittest.mock
+
+from http.server import HTTPServer
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
+
try:
import ssl
except ImportError: # pragma: no cover
@@ -70,42 +76,51 @@ def run_once(loop):
loop.run_forever()
-@contextlib.contextmanager
-def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
+class SilentWSGIRequestHandler(WSGIRequestHandler):
- class SilentWSGIRequestHandler(WSGIRequestHandler):
- def get_stderr(self):
- return io.StringIO()
+ def get_stderr(self):
+ return io.StringIO()
- def log_message(self, format, *args):
- pass
+ def log_message(self, format, *args):
+ pass
- class SilentWSGIServer(WSGIServer):
- def handle_error(self, request, client_address):
+
+class SilentWSGIServer(WSGIServer):
+
+ def handle_error(self, request, client_address):
+ pass
+
+
+class SSLWSGIServerMixin:
+
+ def finish_request(self, request, client_address):
+ # The relative location of our test directory (which
+ # contains the ssl key and certificate files) differs
+ # between the stdlib and stand-alone asyncio.
+ # Prefer our own if we can find it.
+ here = os.path.join(os.path.dirname(__file__), '..', 'tests')
+ if not os.path.isdir(here):
+ here = os.path.join(os.path.dirname(os.__file__),
+ 'test', 'test_asyncio')
+ keyfile = os.path.join(here, 'ssl_key.pem')
+ certfile = os.path.join(here, 'ssl_cert.pem')
+ ssock = ssl.wrap_socket(request,
+ keyfile=keyfile,
+ certfile=certfile,
+ server_side=True)
+ try:
+ self.RequestHandlerClass(ssock, client_address, self)
+ ssock.close()
+ except OSError:
+ # maybe socket has been closed by peer
pass
- class SSLWSGIServer(SilentWSGIServer):
- def finish_request(self, request, client_address):
- # The relative location of our test directory (which
- # contains the ssl key and certificate files) differs
- # between the stdlib and stand-alone asyncio.
- # Prefer our own if we can find it.
- here = os.path.join(os.path.dirname(__file__), '..', 'tests')
- if not os.path.isdir(here):
- here = os.path.join(os.path.dirname(os.__file__),
- 'test', 'test_asyncio')
- keyfile = os.path.join(here, 'ssl_key.pem')
- certfile = os.path.join(here, 'ssl_cert.pem')
- ssock = ssl.wrap_socket(request,
- keyfile=keyfile,
- certfile=certfile,
- server_side=True)
- try:
- self.RequestHandlerClass(ssock, client_address, self)
- ssock.close()
- except OSError:
- # maybe socket has been closed by peer
- pass
+
+class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
+ pass
+
+
+def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
def app(environ, start_response):
status = '200 OK'
@@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
# Run the test WSGI server in a separate thread in order not to
# interfere with event handling in the main thread
- server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
- httpd = make_server(host, port, app,
- server_class, SilentWSGIRequestHandler)
+ server_class = server_ssl_cls if use_ssl else server_cls
+ httpd = server_class(address, SilentWSGIRequestHandler)
+ httpd.set_app(app)
httpd.address = httpd.server_address
server_thread = threading.Thread(target=httpd.serve_forever)
server_thread.start()
@@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
server_thread.join()
+if hasattr(socket, 'AF_UNIX'):
+
+ class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
+
+ def server_bind(self):
+ socketserver.UnixStreamServer.server_bind(self)
+ self.server_name = '127.0.0.1'
+ self.server_port = 80
+
+
+ class UnixWSGIServer(UnixHTTPServer, WSGIServer):
+
+ def server_bind(self):
+ UnixHTTPServer.server_bind(self)
+ self.setup_environ()
+
+ def get_request(self):
+ request, client_addr = super().get_request()
+ # Code in the stdlib expects that get_request
+ # will return a socket and a tuple (host, port).
+ # However, this isn't true for UNIX sockets,
+ # as the second return value will be a path;
+ # hence we return some fake data sufficient
+ # to get the tests going
+ return request, ('127.0.0.1', '')
+
+
+ class SilentUnixWSGIServer(UnixWSGIServer):
+
+ def handle_error(self, request, client_address):
+ pass
+
+
+ class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
+ pass
+
+
+ def gen_unix_socket_path():
+ with tempfile.NamedTemporaryFile() as file:
+ return file.name
+
+
+ @contextlib.contextmanager
+ def unix_socket_path():
+ path = gen_unix_socket_path()
+ try:
+ yield path
+ finally:
+ try:
+ os.unlink(path)
+ except OSError:
+ pass
+
+
+ @contextlib.contextmanager
+ def run_test_unix_server(*, use_ssl=False):
+ with unix_socket_path() as path:
+ yield from _run_test_server(address=path, use_ssl=use_ssl,
+ server_cls=SilentUnixWSGIServer,
+ server_ssl_cls=UnixSSLWSGIServer)
+
+
+@contextlib.contextmanager
+def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
+ yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
+ server_cls=SilentWSGIServer,
+ server_ssl_cls=SSLWSGIServer)
+
+
def make_test_protocol(base):
dct = {}
for name in dir(base):
@@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop):
def _write_to_self(self):
pass
+
def MockCallback(**kwargs):
return unittest.mock.Mock(spec=['__call__'], **kwargs)