summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/asyncio')
-rw-r--r--Lib/asyncio/base_events.py7
-rw-r--r--Lib/asyncio/events.py26
-rw-r--r--Lib/asyncio/streams.py39
-rw-r--r--Lib/asyncio/test_utils.py153
-rw-r--r--Lib/asyncio/unix_events.py75
5 files changed, 263 insertions, 37 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 3bbf6b5..b74e936 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -407,6 +407,13 @@ class BaseEventLoop(events.AbstractEventLoop):
sock.setblocking(False)
+ transport, protocol = yield from self._create_connection_transport(
+ sock, protocol_factory, ssl, server_hostname)
+ return transport, protocol
+
+ @tasks.coroutine
+ def _create_connection_transport(self, sock, protocol_factory, ssl,
+ server_hostname):
protocol = protocol_factory()
waiter = futures.Future(loop=self)
if ssl:
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index dd9e3fb..7841ad9 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -220,6 +220,32 @@ class AbstractEventLoop:
"""
raise NotImplementedError
+ def create_unix_connection(self, protocol_factory, path, *,
+ ssl=None, sock=None,
+ server_hostname=None):
+ raise NotImplementedError
+
+ def create_unix_server(self, protocol_factory, path, *,
+ sock=None, backlog=100, ssl=None):
+ """A coroutine which creates a UNIX Domain Socket server.
+
+ The return valud is a Server object, which can be used to stop
+ the service.
+
+ path is a str, representing a file systsem path to bind the
+ server socket to.
+
+ sock can optionally be specified in order to use a preexisting
+ socket object.
+
+ backlog is the maximum number of queued connections passed to
+ listen() (defaults to 100).
+
+ ssl can be set to an SSLContext to enable SSL over the
+ accepted connections.
+ """
+ raise NotImplementedError
+
def create_datagram_endpoint(self, protocol_factory,
local_addr=None, remote_addr=None, *,
family=0, proto=0, flags=0):
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 8fc2147..698c5c6 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -1,9 +1,13 @@
"""Stream-related things."""
__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
- 'open_connection', 'start_server', 'IncompleteReadError',
+ 'open_connection', 'start_server',
+ 'open_unix_connection', 'start_unix_server',
+ 'IncompleteReadError',
]
+import socket
+
from . import events
from . import futures
from . import protocols
@@ -93,6 +97,39 @@ def start_server(client_connected_cb, host=None, port=None, *,
return (yield from loop.create_server(factory, host, port, **kwds))
+if hasattr(socket, 'AF_UNIX'):
+ # UNIX Domain Sockets are supported on this platform
+
+ @tasks.coroutine
+ def open_unix_connection(path=None, *,
+ loop=None, limit=_DEFAULT_LIMIT, **kwds):
+ """Similar to `open_connection` but works with UNIX Domain Sockets."""
+ if loop is None:
+ loop = events.get_event_loop()
+ reader = StreamReader(limit=limit, loop=loop)
+ protocol = StreamReaderProtocol(reader, loop=loop)
+ transport, _ = yield from loop.create_unix_connection(
+ lambda: protocol, path, **kwds)
+ writer = StreamWriter(transport, protocol, reader, loop)
+ return reader, writer
+
+
+ @tasks.coroutine
+ def start_unix_server(client_connected_cb, path=None, *,
+ loop=None, limit=_DEFAULT_LIMIT, **kwds):
+ """Similar to `start_server` but works with UNIX Domain Sockets."""
+ if loop is None:
+ loop = events.get_event_loop()
+
+ def factory():
+ reader = StreamReader(limit=limit, loop=loop)
+ protocol = StreamReaderProtocol(reader, client_connected_cb,
+ loop=loop)
+ return protocol
+
+ return (yield from loop.create_unix_server(factory, path, **kwds))
+
+
class FlowControlMixin(protocols.Protocol):
"""Reusable flow control logic for StreamWriter.drain().
diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py
index deab7c3..de2916b 100644
--- a/Lib/asyncio/test_utils.py
+++ b/Lib/asyncio/test_utils.py
@@ -4,12 +4,18 @@ import collections
import contextlib
import io
import os
+import socket
+import socketserver
import sys
+import tempfile
import threading
import time
import unittest
import unittest.mock
+
+from http.server import HTTPServer
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
+
try:
import ssl
except ImportError: # pragma: no cover
@@ -70,42 +76,51 @@ def run_once(loop):
loop.run_forever()
-@contextlib.contextmanager
-def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
+class SilentWSGIRequestHandler(WSGIRequestHandler):
- class SilentWSGIRequestHandler(WSGIRequestHandler):
- def get_stderr(self):
- return io.StringIO()
+ def get_stderr(self):
+ return io.StringIO()
- def log_message(self, format, *args):
- pass
+ def log_message(self, format, *args):
+ pass
- class SilentWSGIServer(WSGIServer):
- def handle_error(self, request, client_address):
+
+class SilentWSGIServer(WSGIServer):
+
+ def handle_error(self, request, client_address):
+ pass
+
+
+class SSLWSGIServerMixin:
+
+ def finish_request(self, request, client_address):
+ # The relative location of our test directory (which
+ # contains the ssl key and certificate files) differs
+ # between the stdlib and stand-alone asyncio.
+ # Prefer our own if we can find it.
+ here = os.path.join(os.path.dirname(__file__), '..', 'tests')
+ if not os.path.isdir(here):
+ here = os.path.join(os.path.dirname(os.__file__),
+ 'test', 'test_asyncio')
+ keyfile = os.path.join(here, 'ssl_key.pem')
+ certfile = os.path.join(here, 'ssl_cert.pem')
+ ssock = ssl.wrap_socket(request,
+ keyfile=keyfile,
+ certfile=certfile,
+ server_side=True)
+ try:
+ self.RequestHandlerClass(ssock, client_address, self)
+ ssock.close()
+ except OSError:
+ # maybe socket has been closed by peer
pass
- class SSLWSGIServer(SilentWSGIServer):
- def finish_request(self, request, client_address):
- # The relative location of our test directory (which
- # contains the ssl key and certificate files) differs
- # between the stdlib and stand-alone asyncio.
- # Prefer our own if we can find it.
- here = os.path.join(os.path.dirname(__file__), '..', 'tests')
- if not os.path.isdir(here):
- here = os.path.join(os.path.dirname(os.__file__),
- 'test', 'test_asyncio')
- keyfile = os.path.join(here, 'ssl_key.pem')
- certfile = os.path.join(here, 'ssl_cert.pem')
- ssock = ssl.wrap_socket(request,
- keyfile=keyfile,
- certfile=certfile,
- server_side=True)
- try:
- self.RequestHandlerClass(ssock, client_address, self)
- ssock.close()
- except OSError:
- # maybe socket has been closed by peer
- pass
+
+class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
+ pass
+
+
+def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
def app(environ, start_response):
status = '200 OK'
@@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
# Run the test WSGI server in a separate thread in order not to
# interfere with event handling in the main thread
- server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
- httpd = make_server(host, port, app,
- server_class, SilentWSGIRequestHandler)
+ server_class = server_ssl_cls if use_ssl else server_cls
+ httpd = server_class(address, SilentWSGIRequestHandler)
+ httpd.set_app(app)
httpd.address = httpd.server_address
server_thread = threading.Thread(target=httpd.serve_forever)
server_thread.start()
@@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
server_thread.join()
+if hasattr(socket, 'AF_UNIX'):
+
+ class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
+
+ def server_bind(self):
+ socketserver.UnixStreamServer.server_bind(self)
+ self.server_name = '127.0.0.1'
+ self.server_port = 80
+
+
+ class UnixWSGIServer(UnixHTTPServer, WSGIServer):
+
+ def server_bind(self):
+ UnixHTTPServer.server_bind(self)
+ self.setup_environ()
+
+ def get_request(self):
+ request, client_addr = super().get_request()
+ # Code in the stdlib expects that get_request
+ # will return a socket and a tuple (host, port).
+ # However, this isn't true for UNIX sockets,
+ # as the second return value will be a path;
+ # hence we return some fake data sufficient
+ # to get the tests going
+ return request, ('127.0.0.1', '')
+
+
+ class SilentUnixWSGIServer(UnixWSGIServer):
+
+ def handle_error(self, request, client_address):
+ pass
+
+
+ class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
+ pass
+
+
+ def gen_unix_socket_path():
+ with tempfile.NamedTemporaryFile() as file:
+ return file.name
+
+
+ @contextlib.contextmanager
+ def unix_socket_path():
+ path = gen_unix_socket_path()
+ try:
+ yield path
+ finally:
+ try:
+ os.unlink(path)
+ except OSError:
+ pass
+
+
+ @contextlib.contextmanager
+ def run_test_unix_server(*, use_ssl=False):
+ with unix_socket_path() as path:
+ yield from _run_test_server(address=path, use_ssl=use_ssl,
+ server_cls=SilentUnixWSGIServer,
+ server_ssl_cls=UnixSSLWSGIServer)
+
+
+@contextlib.contextmanager
+def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
+ yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
+ server_cls=SilentWSGIServer,
+ server_ssl_cls=SSLWSGIServer)
+
+
def make_test_protocol(base):
dct = {}
for name in dir(base):
@@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop):
def _write_to_self(self):
pass
+
def MockCallback(**kwargs):
return unittest.mock.Mock(spec=['__call__'], **kwargs)
diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py
index ea79d33..e0d7507 100644
--- a/Lib/asyncio/unix_events.py
+++ b/Lib/asyncio/unix_events.py
@@ -11,6 +11,7 @@ import sys
import threading
+from . import base_events
from . import base_subprocess
from . import constants
from . import events
@@ -31,9 +32,9 @@ if sys.platform == 'win32': # pragma: no cover
class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
- """Unix event loop
+ """Unix event loop.
- Adds signal handling to SelectorEventLoop
+ Adds signal handling and UNIX Domain Socket support to SelectorEventLoop.
"""
def __init__(self, selector=None):
@@ -164,6 +165,76 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
def _child_watcher_callback(self, pid, returncode, transp):
self.call_soon_threadsafe(transp._process_exited, returncode)
+ @tasks.coroutine
+ def create_unix_connection(self, protocol_factory, path, *,
+ ssl=None, sock=None,
+ server_hostname=None):
+ assert server_hostname is None or isinstance(server_hostname, str)
+ if ssl:
+ if server_hostname is None:
+ raise ValueError(
+ 'you have to pass server_hostname when using ssl')
+ else:
+ if server_hostname is not None:
+ raise ValueError('server_hostname is only meaningful with ssl')
+
+ if path is not None:
+ if sock is not None:
+ raise ValueError(
+ 'path and sock can not be specified at the same time')
+
+ try:
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
+ sock.setblocking(False)
+ yield from self.sock_connect(sock, path)
+ except OSError:
+ if sock is not None:
+ sock.close()
+ raise
+
+ else:
+ if sock is None:
+ raise ValueError('no path and sock were specified')
+ sock.setblocking(False)
+
+ transport, protocol = yield from self._create_connection_transport(
+ sock, protocol_factory, ssl, server_hostname)
+ return transport, protocol
+
+ @tasks.coroutine
+ def create_unix_server(self, protocol_factory, path=None, *,
+ sock=None, backlog=100, ssl=None):
+ if isinstance(ssl, bool):
+ raise TypeError('ssl argument must be an SSLContext or None')
+
+ if path is not None:
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+
+ try:
+ sock.bind(path)
+ except OSError as exc:
+ if exc.errno == errno.EADDRINUSE:
+ # Let's improve the error message by adding
+ # with what exact address it occurs.
+ msg = 'Address {!r} is already in use'.format(path)
+ raise OSError(errno.EADDRINUSE, msg) from None
+ else:
+ raise
+ else:
+ if sock is None:
+ raise ValueError(
+ 'path was not specified, and no sock specified')
+
+ if sock.family != socket.AF_UNIX:
+ raise ValueError(
+ 'A UNIX Domain Socket was expected, got {!r}'.format(sock))
+
+ server = base_events.Server(self, [sock])
+ sock.listen(backlog)
+ sock.setblocking(False)
+ self._start_serving(protocol_factory, sock, ssl, server)
+ return server
+
def _set_nonblocking(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL)