diff options
author | Yury Selivanov <yselivanov@sprymix.com> | 2014-02-18 17:15:06 (GMT) |
---|---|---|
committer | Yury Selivanov <yselivanov@sprymix.com> | 2014-02-18 17:15:06 (GMT) |
commit | b057c52b0175a31ee8a786a780bd0eef785820ae (patch) | |
tree | 46e53f51064937eb2abdf7904ce7050af14a5322 /Lib/asyncio | |
parent | fd9d374ae9e1f73e74d0e9a342dd03c7a024d570 (diff) | |
download | cpython-b057c52b0175a31ee8a786a780bd0eef785820ae.zip cpython-b057c52b0175a31ee8a786a780bd0eef785820ae.tar.gz cpython-b057c52b0175a31ee8a786a780bd0eef785820ae.tar.bz2 |
asyncio: Add support for UNIX Domain Sockets.
Diffstat (limited to 'Lib/asyncio')
-rw-r--r-- | Lib/asyncio/base_events.py | 7 | ||||
-rw-r--r-- | Lib/asyncio/events.py | 26 | ||||
-rw-r--r-- | Lib/asyncio/streams.py | 39 | ||||
-rw-r--r-- | Lib/asyncio/test_utils.py | 153 | ||||
-rw-r--r-- | Lib/asyncio/unix_events.py | 75 |
5 files changed, 263 insertions, 37 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 3bbf6b5..b74e936 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -407,6 +407,13 @@ class BaseEventLoop(events.AbstractEventLoop): sock.setblocking(False) + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname) + return transport, protocol + + @tasks.coroutine + def _create_connection_transport(self, sock, protocol_factory, ssl, + server_hostname): protocol = protocol_factory() waiter = futures.Future(loop=self) if ssl: diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index dd9e3fb..7841ad9 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -220,6 +220,32 @@ class AbstractEventLoop: """ raise NotImplementedError + def create_unix_connection(self, protocol_factory, path, *, + ssl=None, sock=None, + server_hostname=None): + raise NotImplementedError + + def create_unix_server(self, protocol_factory, path, *, + sock=None, backlog=100, ssl=None): + """A coroutine which creates a UNIX Domain Socket server. + + The return valud is a Server object, which can be used to stop + the service. + + path is a str, representing a file systsem path to bind the + server socket to. + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + """ + raise NotImplementedError + def create_datagram_endpoint(self, protocol_factory, local_addr=None, remote_addr=None, *, family=0, proto=0, flags=0): diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 8fc2147..698c5c6 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,9 +1,13 @@ """Stream-related things.""" __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', - 'open_connection', 'start_server', 'IncompleteReadError', + 'open_connection', 'start_server', + 'open_unix_connection', 'start_unix_server', + 'IncompleteReadError', ] +import socket + from . import events from . import futures from . import protocols @@ -93,6 +97,39 @@ def start_server(client_connected_cb, host=None, port=None, *, return (yield from loop.create_server(factory, host, port, **kwds)) +if hasattr(socket, 'AF_UNIX'): + # UNIX Domain Sockets are supported on this platform + + @tasks.coroutine + def open_unix_connection(path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `open_connection` but works with UNIX Domain Sockets.""" + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop) + transport, _ = yield from loop.create_unix_connection( + lambda: protocol, path, **kwds) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer + + + @tasks.coroutine + def start_unix_server(client_connected_cb, path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `start_server` but works with UNIX Domain Sockets.""" + if loop is None: + loop = events.get_event_loop() + + def factory(): + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop) + return protocol + + return (yield from loop.create_unix_server(factory, path, **kwds)) + + class FlowControlMixin(protocols.Protocol): """Reusable flow control logic for StreamWriter.drain(). diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py index deab7c3..de2916b 100644 --- a/Lib/asyncio/test_utils.py +++ b/Lib/asyncio/test_utils.py @@ -4,12 +4,18 @@ import collections import contextlib import io import os +import socket +import socketserver import sys +import tempfile import threading import time import unittest import unittest.mock + +from http.server import HTTPServer from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer + try: import ssl except ImportError: # pragma: no cover @@ -70,42 +76,51 @@ def run_once(loop): loop.run_forever() -@contextlib.contextmanager -def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): +class SilentWSGIRequestHandler(WSGIRequestHandler): - class SilentWSGIRequestHandler(WSGIRequestHandler): - def get_stderr(self): - return io.StringIO() + def get_stderr(self): + return io.StringIO() - def log_message(self, format, *args): - pass + def log_message(self, format, *args): + pass - class SilentWSGIServer(WSGIServer): - def handle_error(self, request, client_address): + +class SilentWSGIServer(WSGIServer): + + def handle_error(self, request, client_address): + pass + + +class SSLWSGIServerMixin: + + def finish_request(self, request, client_address): + # The relative location of our test directory (which + # contains the ssl key and certificate files) differs + # between the stdlib and stand-alone asyncio. + # Prefer our own if we can find it. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + if not os.path.isdir(here): + here = os.path.join(os.path.dirname(os.__file__), + 'test', 'test_asyncio') + keyfile = os.path.join(here, 'ssl_key.pem') + certfile = os.path.join(here, 'ssl_cert.pem') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer pass - class SSLWSGIServer(SilentWSGIServer): - def finish_request(self, request, client_address): - # The relative location of our test directory (which - # contains the ssl key and certificate files) differs - # between the stdlib and stand-alone asyncio. - # Prefer our own if we can find it. - here = os.path.join(os.path.dirname(__file__), '..', 'tests') - if not os.path.isdir(here): - here = os.path.join(os.path.dirname(os.__file__), - 'test', 'test_asyncio') - keyfile = os.path.join(here, 'ssl_key.pem') - certfile = os.path.join(here, 'ssl_cert.pem') - ssock = ssl.wrap_socket(request, - keyfile=keyfile, - certfile=certfile, - server_side=True) - try: - self.RequestHandlerClass(ssock, client_address, self) - ssock.close() - except OSError: - # maybe socket has been closed by peer - pass + +class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): + pass + + +def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): def app(environ, start_response): status = '200 OK' @@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): # Run the test WSGI server in a separate thread in order not to # interfere with event handling in the main thread - server_class = SSLWSGIServer if use_ssl else SilentWSGIServer - httpd = make_server(host, port, app, - server_class, SilentWSGIRequestHandler) + server_class = server_ssl_cls if use_ssl else server_cls + httpd = server_class(address, SilentWSGIRequestHandler) + httpd.set_app(app) httpd.address = httpd.server_address server_thread = threading.Thread(target=httpd.serve_forever) server_thread.start() @@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): server_thread.join() +if hasattr(socket, 'AF_UNIX'): + + class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): + + def server_bind(self): + socketserver.UnixStreamServer.server_bind(self) + self.server_name = '127.0.0.1' + self.server_port = 80 + + + class UnixWSGIServer(UnixHTTPServer, WSGIServer): + + def server_bind(self): + UnixHTTPServer.server_bind(self) + self.setup_environ() + + def get_request(self): + request, client_addr = super().get_request() + # Code in the stdlib expects that get_request + # will return a socket and a tuple (host, port). + # However, this isn't true for UNIX sockets, + # as the second return value will be a path; + # hence we return some fake data sufficient + # to get the tests going + return request, ('127.0.0.1', '') + + + class SilentUnixWSGIServer(UnixWSGIServer): + + def handle_error(self, request, client_address): + pass + + + class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer): + pass + + + def gen_unix_socket_path(): + with tempfile.NamedTemporaryFile() as file: + return file.name + + + @contextlib.contextmanager + def unix_socket_path(): + path = gen_unix_socket_path() + try: + yield path + finally: + try: + os.unlink(path) + except OSError: + pass + + + @contextlib.contextmanager + def run_test_unix_server(*, use_ssl=False): + with unix_socket_path() as path: + yield from _run_test_server(address=path, use_ssl=use_ssl, + server_cls=SilentUnixWSGIServer, + server_ssl_cls=UnixSSLWSGIServer) + + +@contextlib.contextmanager +def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): + yield from _run_test_server(address=(host, port), use_ssl=use_ssl, + server_cls=SilentWSGIServer, + server_ssl_cls=SSLWSGIServer) + + def make_test_protocol(base): dct = {} for name in dir(base): @@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop): def _write_to_self(self): pass + def MockCallback(**kwargs): return unittest.mock.Mock(spec=['__call__'], **kwargs) diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index ea79d33..e0d7507 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -11,6 +11,7 @@ import sys import threading +from . import base_events from . import base_subprocess from . import constants from . import events @@ -31,9 +32,9 @@ if sys.platform == 'win32': # pragma: no cover class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): - """Unix event loop + """Unix event loop. - Adds signal handling to SelectorEventLoop + Adds signal handling and UNIX Domain Socket support to SelectorEventLoop. """ def __init__(self, selector=None): @@ -164,6 +165,76 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): def _child_watcher_callback(self, pid, returncode, transp): self.call_soon_threadsafe(transp._process_exited, returncode) + @tasks.coroutine + def create_unix_connection(self, protocol_factory, path, *, + ssl=None, sock=None, + server_hostname=None): + assert server_hostname is None or isinstance(server_hostname, str) + if ssl: + if server_hostname is None: + raise ValueError( + 'you have to pass server_hostname when using ssl') + else: + if server_hostname is not None: + raise ValueError('server_hostname is only meaningful with ssl') + + if path is not None: + if sock is not None: + raise ValueError( + 'path and sock can not be specified at the same time') + + try: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) + sock.setblocking(False) + yield from self.sock_connect(sock, path) + except OSError: + if sock is not None: + sock.close() + raise + + else: + if sock is None: + raise ValueError('no path and sock were specified') + sock.setblocking(False) + + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname) + return transport, protocol + + @tasks.coroutine + def create_unix_server(self, protocol_factory, path=None, *, + sock=None, backlog=100, ssl=None): + if isinstance(ssl, bool): + raise TypeError('ssl argument must be an SSLContext or None') + + if path is not None: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + try: + sock.bind(path) + except OSError as exc: + if exc.errno == errno.EADDRINUSE: + # Let's improve the error message by adding + # with what exact address it occurs. + msg = 'Address {!r} is already in use'.format(path) + raise OSError(errno.EADDRINUSE, msg) from None + else: + raise + else: + if sock is None: + raise ValueError( + 'path was not specified, and no sock specified') + + if sock.family != socket.AF_UNIX: + raise ValueError( + 'A UNIX Domain Socket was expected, got {!r}'.format(sock)) + + server = base_events.Server(self, [sock]) + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl, server) + return server + def _set_nonblocking(fd): flags = fcntl.fcntl(fd, fcntl.F_GETFL) |